Compare commits

...

3 Commits

Author SHA1 Message Date
xxy
bf3d340aa8 Merge origin/main — keep local version
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-05 18:48:31 +08:00
xxy
88793da902 Add .env file
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-05 18:48:04 +08:00
xxy
43f3e0b746 Initial commit
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-05 18:41:06 +08:00
24 changed files with 3598 additions and 2249 deletions

39
.env
View File

@ -1,35 +1,22 @@
# 报告生成独立服务环境配置(沿用原 eval_report 的外部服务) # 范文/模板解析子服务section_reference_block环境配置
# 与网关 eval_report、report_generation 同库
# 数据库MySQL与原项目同库) # 远程 MySQL章节内容入库目标与主项目同库)
DATABASE_URL=mysql+pymysql://root:Beidas0ft@192.168.4.177:3306/eval_report?charset=utf8mb4 DATABASE_URL=mysql+pymysql://root:Beidas0ft@192.168.4.177:3306/eval_report?charset=utf8mb4
DB_POOL_SIZE=15 DB_AUTO_CREATE_TABLES=true
DB_MAX_OVERFLOW=25
DB_POOL_TIMEOUT=60
DB_POOL_PRE_PING=true
# 文档存储根目录(附图提取按 DOC_PAT/{project_uuid}/<相对路径> 定位 .docx # 远程文档解析服务(上传文档 → Markdown
# 指向原项目的 docs 目录,保证附图能被找到 FILE_PARSE_API_URL=http://192.168.4.194:8000/convert
DOC_PAT=D:/Git-Project/eval_report/docs FILE_PARSE_FIELD_NAME=file
FILE_PARSE_ENGINE=auto
FILE_PARSE_HTTP_TIMEOUT_SEC=600
# Embedding 模型配置 # LLM为每个目录生成"声明";留空则使用确定性兜底模板)
EMBEDDING_API_KEY=sk-xtbpjqekezfbttasgbrczzskqenygmhqwpsobpiagwrlacfr
EMBEDDING_API_BASE=http://192.168.4.191:8001/v1
EMBEDDING_BATCH_MAX_DOCS=4
EMBEDDING_BATCH_MAX_CHARS=12000
EMBEDDING_MAX_CHUNK_CHARS=4000
# LLMOpenAI 兼容接口)
LLM_API_BASE=http://192.168.4.197:8086/v1 LLM_API_BASE=http://192.168.4.197:8086/v1
LLM_API_KEY=sk-99999999991234 LLM_API_KEY=sk-99999999991234
LLM_MODEL_NAME=Qwen3.6-27B LLM_MODEL_NAME=Qwen3.6-27B
LLM_HTTP_TIMEOUT_SEC=600 DECLARATION_USE_LLM=true
# 报告章节单次 chat 读超时(秒),长章节建议 600+
REPORT_LLM_HTTP_TIMEOUT_SEC=600
# Milvus 向量数据库 # 服务监听
MILVUS_DB_URL=http://192.168.4.191:19530
# 服务监听注意8099 已被网关 eval_report 占用,本子服务改用 8101
HOST=0.0.0.0 HOST=0.0.0.0
PORT=8101 PORT=8100
RELOAD=false

View File

@ -1,24 +1,21 @@
# 复制为 .env 后按实际环境填写。 # 复制为 .env 并按实际环境修改
# 数据库MySQL与原 eval_report 共用同一库) # 远程 MySQL章节内容入库目标
DATABASE_URL=mysql+pymysql://root:123456@127.0.0.1:3306/post_eval_report?charset=utf8mb4 DATABASE_URL=mysql+pymysql://root:Beidas0ft@192.168.4.177:3306/eval_report?charset=utf8mb4
DB_AUTO_CREATE_TABLES=true
# 文档存储根目录(附图提取按 DOC_PAT/{project_uuid}/<相对路径> 定位 .docx # 远程文档解析服务(上传文档 → Markdown
DOC_PAT=./docpath FILE_PARSE_API_URL=http://192.168.4.194:8000/convert
FILE_PARSE_FIELD_NAME=file
FILE_PARSE_ENGINE=auto
FILE_PARSE_HTTP_TIMEOUT_SEC=600
# LLMOpenAI 兼容接口) # LLM可选为每个目录生成"声明"。留空则使用确定性兜底模板。
LLM_API_BASE= LLM_API_BASE=http://192.168.4.197:8086/v1
LLM_API_KEY= LLM_API_KEY=sk-99999999991234
LLM_MODEL_NAME= LLM_MODEL_NAME=Qwen3.6-27B
# 报告章节单次 chat 读超时(秒),长章节建议 600+ DECLARATION_USE_LLM=true
REPORT_LLM_HTTP_TIMEOUT_SEC=600
# Embedding / Milvus向量检索证据
EMBEDDING_API_BASE=
EMBEDDING_API_KEY=
MILVUS_DB_URL=
# 服务监听 # 服务监听
HOST=0.0.0.0 HOST=0.0.0.0
PORT=8099 PORT=8100
RELOAD=false

28
.gitignore vendored
View File

@ -1,21 +1,13 @@
# Python-generated files # Python
__pycache__/ __pycache__/
*.py[oc] *.py[cod]
build/ .venv/
dist/ venv/
wheels/ *.egg-info/
*.egg-info
# Virtual environments # 环境与日志
.venv
# Environment / secrets
# .env — tracked intentionally
# Local artifacts
*.log
.DS_Store
comp/
docpath/
docs/
logs/ logs/
# IDE
.idea/
.vscode/

112
README.md
View File

@ -1,52 +1,98 @@
# 报告生成服务(独立抽取版) # 报告模板管理模块
`eval_report` 中抽取出的「后评价报告核心生成」链路,作为独立 FastAPI 服务运行。 上传一个文档,自动完成:
保留原有的证据装配(要素表 + Milvus 向量检索)、分章 LLM 生成、表格修复、报告合并与 SSE 流式进度,
连接与原项目相同的 MySQL / Milvus / LLM 服务。
## 范围 1. **远程解析**:调用 `http://192.168.4.194:8000/convert`(表单字段 `file` + `engine=auto`)将文档转换为 Markdown。
2. **抽取目录**:从 Markdown 中识别章节标题层级(目录)。
3. **生成声明**:为每个目录(章节)生成一段"章节声明"(撰写指引),存入模板。
4. **脱敏入库**:按标题拆分正文,对每个章节正文**脱敏**(去掉精确数字/金额/日期/百分比等),再按远程 MySQL `report_section_references` 表格式写入,得到可复用的模板化范文。
- 包含异步分章生成任务、进度查询、结果获取、SSE 实时事件、章节重试、任务取消。 解析、目录抽取、正文拆分逻辑参考 `eval_report/routers/template.py``routers/reference.py`
- 不含:鉴权、知识库 worker、模板/范文管理、Word(docx) 导出(这些仍在原 `eval_report` 中)。
## 目录结构 ## 目录结构
``` ```
report_generation/ config.py 全局配置DB / 解析服务 / LLM
main.py FastAPI 入口 main.py FastAPI 入口
config.py 配置DB / LLM / Embedding / Milvus / DOC_PAT database/ 连接、ORM 模型、建表
database/ SQLAlchemy 引擎、Session、ORM 模型、建表 models.py report_templates / report_template_sections / report_section_references
schemas/ Pydantic 模型 schemas/template.py 接口出入参
services/ 报告生成核心逻辑(含瘦身版 kb_service / docx_export_service / project_service services/
function/vector_store.py Milvus 向量库封装 file_parse_client.py 调用远程 /convert → Markdown
prompts/report_generation/ 提示词模板与章节合同 section_extractor.py 目录抽取 + 正文按标题拆分(共用同一遍历)
routers/report.py 报告生成 HTTP 端点 desensitize_service.py 章节正文脱敏(去精确数字等)
declaration_service.py 为每个目录生成"声明"LLM 可选 + 兜底模板)
llm_client.py OpenAI 兼容 Chat 客户端(可选)
routers/template.py 上传/列表/详情/删除
``` ```
## 快速开始 ## 配置
复制 `.env.example``.env` 并修改:
- `DATABASE_URL`:远程 MySQL章节内容入库目标
- `FILE_PARSE_API_URL`:远程文档解析服务(默认 `http://192.168.4.194:8000/convert`,文件字段 `FILE_PARSE_FIELD_NAME=file`,引擎 `FILE_PARSE_ENGINE=auto`)。
- `LLM_*`:可选。配置后用 LLM 生成更贴合的章节声明;留空则使用确定性兜底模板。
启动时会按需在远程库中创建本模块用到的三张表(`DB_AUTO_CREATE_TABLES=true`,已存在则跳过)。
## 运行
```bash ```bash
pip install -r requirements.txt pip install -r requirements.txt
cp .env.example .env # 按需填写 DATABASE_URL / LLM_* / EMBEDDING_* / MILVUS_DB_URL python main.py
uvicorn main:app --reload # 或
uvicorn main:app --host 0.0.0.0 --port 8100
``` ```
启动后访问 `http://127.0.0.1:8099/docs` 查看接口文档,`/health` 做健康检查。 打开 `http://localhost:8100/docs` 查看接口文档
## 主要接口(前缀 `/api/v1/write` ## 主要接口
| 方法 | 路径 | 说明 | | 方法 | 路径 | 说明 |
|------|------|------| | --- | --- | --- |
| GET | `/projects/{project_id}/generate-sections` | 预览模板章节提示词清单 | | POST | `/templates/upload` | 上传文档,解析为模板(目录+声明)并将章节内容入库 |
| POST | `/projects/{project_id}/generate-report-job` | 创建分章异步报告生成任务 | | GET | `/templates` | 模板列表 |
| GET | `/projects/{project_id}/generate-report-job/{job_id}` | 查询任务进度 | | GET | `/templates/{id}` | 模板详情(含目录与各章节声明) |
| GET | `/projects/{project_id}/generate-report-job/{job_id}/result` | 获取任务结果 | | DELETE | `/templates/{id}` | 删除模板 |
| GET | `/projects/{project_id}/generate-report-job/{job_id}/events` | 订阅实时事件SSE | | GET | `/health` | 健康检查 |
| POST | `/projects/{project_id}/generate-report-job/{job_id}/retry-chapter` | 重试指定章节 |
| POST | `/projects/{project_id}/generate-report-job/{job_id}/cancel` | 取消任务 |
## 依赖的外部数据 ### 上传示例
报告生成依赖原库中已有的项目数据:`projects``element_tables` / `element_cells`(要素表)、 ```bash
`report_templates` / `report_template_sections`(模板章节)、可选的 `report_section_references`(参考范文), curl -X POST "http://localhost:8100/templates/upload" \
以及 Milvus 中按项目 UUID 写入的文档向量。请确保新服务连接到已包含这些数据的 MySQL 与 Milvus。 -F "file=@/path/to/报告.docx"
```
返回包含:模板信息(每个目录的 `sectionDeclaration` 即声明)、入库章节数与各章节摘要。
## 日志
启动即初始化日志系统(`log/logger.py`),输出到控制台(强制 UTF-8避免 Windows 中文乱码)并写入 `logs/`
| 文件 | 内容 |
| --- | --- |
| `logs/app.log` | 全量日志(按大小轮转) |
| `logs/error.log` | WARNING 及以上 |
| `logs/upload.log` | 上传/解析/入库链路(`routers.template``services.*` |
- 每个 HTTP 请求会记录方法、路径、状态码、耗时,并在响应头返回 `X-Request-ID`
- uvicorn 的 access/error 日志也统一汇入上述文件。
- 可在 `.env` 调整:`LOG_LEVEL``LOG_DIR``LOG_TO_CONSOLE``LOG_MAX_BYTES``LOG_BACKUP_COUNT``LOG_HTTP_ACCESS`
## 数据落点
- `report_templates`:一条模板记录。
- `report_template_sections`:每个目录一条,`section_prompt` 字段存放该目录的**声明**。
- `report_section_references`:每个章节一条,存放该章节**脱敏后的正文内容**(与远程库现有格式一致)。
### 脱敏规则
`services/desensitize_service.py`
- 阿拉伯数字串(含小数/千分位/全角)→ 占位符(默认 `X``总投资10.5亿元``总投资X亿元``85.3%``X%``2020年3月``X年X月`
- 标题行(`#` 开头)整行保留,不动章节编号与标题。
- 行首枚举序号(`1``2` 等)保留,仅脱敏正文数字。
- 表格分隔行保留;数据格数字默认脱敏(`DESENSITIZE_MASK_TABLE_NUMBERS`)。
- 中文数字(一二三…)默认保留(多为序数/层级)。
- 可在 `.env` 调整:`DESENSITIZE_ENABLED``DESENSITIZE_PLACEHOLDER``DESENSITIZE_MASK_TABLE_NUMBERS`

View File

@ -1,64 +1,86 @@
""" """
config.py config.py
全局配置项可通过 .env 文件或环境变量覆盖 报告模板管理模块的全局配置可通过 .env 或环境变量覆盖
本项目为报告生成独立服务仅保留报告生成链路所需配置
数据库(MySQL) / LLM / Embedding / Milvus / 文档存储路径
""" """
from __future__ import annotations
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
# 应用基本信息 # 应用基本信息
APP_TITLE: str = "智能报告生成服务 API" APP_TITLE: str = "报告模板管理模块 API"
APP_VERSION: str = "0.1.0" APP_VERSION: str = "0.1.0"
APP_DESCRIPTION: str = "后评价报告分章异步生成后端服务(独立抽取版)" APP_DESCRIPTION: str = "上传文档 → 远程解析为 Markdown → 拆解目录/章节 → 入库远程 MySQL"
# 服务监听 # 服务监听
HOST: str = "0.0.0.0" HOST: str = "0.0.0.0"
PORT: int = 8099 PORT: int = 8100
RELOAD: bool = False RELOAD: bool = False
# CORS 允许的前端源(开发阶段放开,生产环境改为具体域名)
CORS_ORIGINS: list[str] = ["*"] CORS_ORIGINS: list[str] = ["*"]
# 数据库MySQL # 日志
DATABASE_URL: str = "mysql+pymysql://root:123456@127.0.0.1:3306/post_eval_report?charset=utf8mb4" LOG_LEVEL: str = "INFO" # DEBUG / INFO / WARNING / ERROR
DB_POOL_SIZE: int = 15 LOG_DIR: str = "logs" # 日志目录(相对启动目录或绝对路径)
DB_MAX_OVERFLOW: int = 25 LOG_TO_CONSOLE: bool = True # 是否同时输出到控制台
LOG_MAX_BYTES: int = 10 * 1024 * 1024 # 单文件最大字节数(轮转)
LOG_BACKUP_COUNT: int = 7 # 轮转保留份数
LOG_HTTP_ACCESS: bool = True # 是否记录每个 HTTP 请求
# 远程 MySQLmysql+pymysql://用户:密码@主机:端口/库名?charset=utf8mb4
DATABASE_URL: str = (
"mysql+pymysql://root:Beidas0ft@192.168.4.177:3306/eval_report?charset=utf8mb4"
)
DB_POOL_SIZE: int = 10
DB_MAX_OVERFLOW: int = 20
DB_POOL_TIMEOUT: int = 60 DB_POOL_TIMEOUT: int = 60
DB_POOL_PRE_PING: bool = True DB_POOL_PRE_PING: bool = True
# 启动时自动建表(仅创建本模块用到的表,已存在则跳过)
DB_AUTO_CREATE_TABLES: bool = True
# 文档存储根目录(附图提取时按 DOC_PAT/{project_uuid}/<相对路径> 定位 .docx # 远程文档解析服务:上传文件 → Markdown
DOC_PAT: str = "./docpath" FILE_PARSE_API_URL: str = "http://192.168.4.194:8000/convert"
FILE_PARSE_FIELD_NAME: str = "file"
# 解析引擎(随 multipart 一起提交的表单字段 engine
FILE_PARSE_ENGINE: str = "auto"
FILE_PARSE_HTTP_TIMEOUT_SEC: int = 600
FILE_PARSE_RETRY_COUNT: int = 3
FILE_PARSE_RETRY_BACKOFF_SEC: float = 15.0
# LLMOpenAI 兼容接口) # 章节正文:是否包含其下级小节内容(章/节聚合整棵子树正文,避免父章节正文为空)
SECTION_CONTENT_INCLUDE_SUBSECTIONS: bool = True
# 单章节正文入库字节上限MySQL TEXT 列上限 65535 字节,留余量防止截断到半个字符)
SECTION_CONTENT_MAX_BYTES: int = 60000
# 章节内容脱敏:入库前过滤精确数据(数字/金额/日期/百分比等)
DESENSITIZE_ENABLED: bool = True
DESENSITIZE_PLACEHOLDER: str = "X" # 数字脱敏后的占位符
# 是否把表格中的数字也脱敏(表格通常是精确数据,默认开启)
DESENSITIZE_MASK_TABLE_NUMBERS: bool = True
# LLM可选为每个目录生成"声明"。未配置时使用确定性兜底模板。
LLM_API_BASE: str = "" LLM_API_BASE: str = ""
LLM_API_KEY: str = "" LLM_API_KEY: str = ""
LLM_MODEL_NAME: str = "" LLM_MODEL_NAME: str = ""
LLM_HTTP_TIMEOUT_SEC: int = 120 LLM_HTTP_TIMEOUT_SEC: int = 120
LLM_CONNECT_TIMEOUT_SEC: int = 30 # 关闭思考模型的思维链输出vLLM/Qwen3 等chat_template_kwargs.enable_thinking=false
LLM_RETRY_COUNT: int = 3 # 既避免"思考过程"混入正文,又减少 token、降低截断与耗时。
LLM_RETRY_BACKOFF_SEC: float = 1.0 LLM_DISABLE_THINKING: bool = True
LLM_RETRY_BACKOFF_MAX_SEC: float = 12.0 # 是否调用 LLM 生成章节声明(关闭则始终使用兜底模板)
# 报告章节单次 chat 读超时。0 表示沿用 LLM_HTTP_TIMEOUT_SEC长章节建议 600+ DECLARATION_USE_LLM: bool = True
REPORT_LLM_HTTP_TIMEOUT_SEC: int = 600 # 上传模版时:用 LLM 匹配默认提示词 / 为无匹配章节生成提示词(复刻 eval_report
# 某章 LLM 仍失败时写入占位正文并继续后续章节,避免整份任务失败 TEMPLATE_UPLOAD_LLM_PROMPT_MAPPING: bool = True
REPORT_LLM_CONTINUE_ON_TIMEOUT: bool = True # LLM 提示词匹配并发:把未匹配章节分批并行调用,缩短整体耗时。
# 表格抽取延迟补抽(首轮失败后进入队列,按轮次延迟重试) # 多卡 A100 + 连续批处理vLLM/TGITP 或多副本)下,提高并发在飞请求数即可打满 GPU
LLM_TABLE_DELAY_RETRY_ROUNDS: int = 2 # - 调小 BATCH_SIZE请求更多更短确保批次数 ≥ 线程数,单请求尾延迟更低
LLM_TABLE_DELAY_RETRY_SEC: float = 8.0 # - 调大 MAX_WORKERS同时在飞的序列更多填满推理服务的批decode 吞吐接近峰值
LLM_TABLE_DELAY_RETRY_BACKOFF: float = 2.0 # - 调小 MAX_TOKENS每序列 KV 缓存预留更少,调度器可纳入更多并发序列
LLM_TABLE_DELAY_RETRY_MAX_SEC: float = 60.0 # 2×A100并发目标约 16较单卡的 8 翻倍BATCH_SIZE=2 保证常见规模也能跑满 16 路。
TEMPLATE_UPLOAD_LLM_BATCH_SIZE: int = 2 # 每批未匹配章节数量
# Embedding / Milvus向量检索证据 L2/L3 TEMPLATE_UPLOAD_LLM_MAX_WORKERS: int = 16 # 并行线程数上限(在飞请求数)
EMBEDDING_API_KEY: str = "" TEMPLATE_UPLOAD_LLM_MAX_TOKENS: int = 2048 # 单批最大输出 token
EMBEDDING_API_BASE: str = ""
EMBEDDING_BATCH_MAX_DOCS: int = 4
EMBEDDING_BATCH_MAX_CHARS: int = 12000
EMBEDDING_MAX_CHUNK_CHARS: int = 4000
MILVUS_DB_URL: str = ""
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", env_file=".env",

View File

@ -1,27 +1,7 @@
""" """database package连接、模型与依赖注入。"""
database
数据库连接与 Session 管理
使用方式 from database.core import SessionLocal, engine
from database import get_db, SessionLocal, init_database
# 依赖注入FastAPI 路由)
@router.get("/items")
def list_items(db: Session = Depends(get_db)):
...
# 上下文管理器脚本、worker
with SessionLocal() as db:
...
"""
from database.core import engine, SessionLocal
from database.dependencies import get_db from database.dependencies import get_db
from database.init_db import init_database from database.init_db import init_database
__all__ = [ __all__ = ["engine", "SessionLocal", "get_db", "init_database"]
"engine",
"SessionLocal",
"get_db",
"init_database",
]

View File

@ -1,42 +1,33 @@
""" """
database/core.py database/core.py
SQLAlchemy 引擎与 Session 工厂 SQLAlchemy 引擎与 Session 工厂同步引擎连接远程 MySQL
- 同步引擎默认连接池QueuePool
- 后续可替换为 create_async_engine 实现异步
""" """
from __future__ import annotations
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.orm import sessionmaker
from config import settings from config import settings
# -----------------------------------------------------------------------------
# 引擎配置
# -----------------------------------------------------------------------------
engine = create_engine( engine = create_engine(
settings.DATABASE_URL, settings.DATABASE_URL,
pool_size=settings.DB_POOL_SIZE, pool_size=settings.DB_POOL_SIZE,
max_overflow=settings.DB_MAX_OVERFLOW, max_overflow=settings.DB_MAX_OVERFLOW,
pool_timeout=settings.DB_POOL_TIMEOUT, pool_timeout=settings.DB_POOL_TIMEOUT,
pool_pre_ping=settings.DB_POOL_PRE_PING, pool_pre_ping=settings.DB_POOL_PRE_PING,
pool_recycle=3600, # 1 小时回收空闲连接,避免 MySQL wait_timeout pool_recycle=3600,
connect_args={ connect_args={
"charset": "utf8mb4", "charset": "utf8mb4",
"use_unicode": True, "use_unicode": True,
"init_command": "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci", "init_command": "SET NAMES utf8mb4 COLLATE utf8mb4_unicode_ci",
}, },
echo=False, # 开发时可设为 True 打印 SQL echo=False,
) )
# -----------------------------------------------------------------------------
# Session 工厂
# -----------------------------------------------------------------------------
SessionLocal = sessionmaker( SessionLocal = sessionmaker(
bind=engine, bind=engine,
autocommit=False, autocommit=False,
autoflush=False, autoflush=False,
expire_on_commit=False, # 提交后对象仍可访问属性,便于返回响应 expire_on_commit=False,
) )

View File

@ -1,11 +1,11 @@
""" """
database/dependencies.py database/dependencies.py
FastAPI 依赖注入获取数据库 Session FastAPI 依赖注入获取数据库 Session
每个请求创建新 Session请求结束后自动关闭
""" """
from collections.abc import Generator from __future__ import annotations
from typing import Generator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -13,14 +13,6 @@ from database.core import SessionLocal
def get_db() -> Generator[Session, None, None]: def get_db() -> Generator[Session, None, None]:
"""
获取数据库 Session用于 FastAPI Depends()
用法
@router.get("/items")
def list_items(db: Session = Depends(get_db)):
...
"""
db = SessionLocal() db = SessionLocal()
try: try:
yield db yield db

View File

@ -1,764 +1,52 @@
""" """
database/init_db.py database/init_db.py
应用启动时初始化数据库表结构 按需建表仅创建本模块用到的三张表已存在则跳过checkfirst=True
执行 init.sql 中的 DDL使用 IF NOT EXISTS 保证幂等
""" """
import re from __future__ import annotations
from pathlib import Path
from sqlalchemy import text import logging
from sqlalchemy import inspect, text
from database.core import engine from database.core import engine
from database.models import Base
# DDL 与 init_db.py 同目录database/init.sql logger = logging.getLogger(__name__)
INIT_SQL_PATH = Path(__file__).resolve().parent / "init.sql"
INIT_TABLES = [
"projects",
"kb_directories",
"kb_documents",
"write_documents",
"doc_versions",
"element_tables",
"element_cells",
"extraction_results",
"element_extraction_results",
"element_conflicts",
"document_markdowns",
"document_chunks",
"report_templates",
"report_template_sections",
"report_generation_jobs",
"report_generation_chapters",
"departments",
"users",
"roles",
"permissions",
"role_permissions",
"user_roles",
"project_members",
"project_departments",
"fill_records",
"report_section_references",
]
_TARGET_TABLE_COLLATION = "utf8mb4_unicode_ci" def _ensure_reference_template_id_column() -> None:
"""为已存在的 report_section_references 表补充 template_id 字段(幂等)。
create_all(checkfirst=True) 只建缺失的表不会给已存在的表加列
def _existing_tables(conn) -> set[str]: 因此这里对历史表做一次轻量级 ALTER仅在缺列时执行
return {
row[0]
for row in conn.execute(
text(
"SELECT TABLE_NAME FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE()"
)
).fetchall()
}
def _table_collation(conn, table_name: str) -> str | None:
row = conn.execute(
text(
"SELECT TABLE_COLLATION FROM information_schema.TABLES "
"WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = :table_name"
),
{"table_name": table_name},
).first()
return str(row[0]).strip() if row and row[0] else None
def _column_collation(conn, table_name: str, column_name: str) -> str | None:
row = conn.execute(
text(
"SELECT COLLATION_NAME FROM information_schema.COLUMNS "
"WHERE TABLE_SCHEMA = DATABASE() "
"AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name"
),
{"table_name": table_name, "column_name": column_name},
).first()
return str(row[0]).strip() if row and row[0] else None
def _normalize_projects_table(conn) -> None:
""" """
将历史库表/列统一为 utf8mb4_unicode_ci仅在实际不一致时执行 ALTER insp = inspect(engine)
if "report_section_references" not in insp.get_table_names():
切勿在每次启动时对已迁移库重复 CONVERT会长时间持有 metadata lock
阻塞所有对 projects 等表的读写并导致连接池耗尽
"""
existing = _existing_tables(conn)
tables_to_convert = [
name
for name in INIT_TABLES
if name in existing and _table_collation(conn, name) != _TARGET_TABLE_COLLATION
]
projects_uuid_needs_fix = (
"projects" in existing
and _column_collation(conn, "projects", "uuid") != _TARGET_TABLE_COLLATION
)
if not tables_to_convert and not projects_uuid_needs_fix:
return return
conn.execute(text("SET FOREIGN_KEY_CHECKS=0")) columns = {c["name"] for c in insp.get_columns("report_section_references")}
try: if "template_id" in columns:
for table_name in tables_to_convert: return
conn.execute(
text(
f"ALTER TABLE `{table_name}` "
"CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
)
)
if projects_uuid_needs_fix:
conn.execute(
text(
"ALTER TABLE projects "
"MODIFY uuid VARCHAR(32) CHARACTER SET utf8mb4 "
"COLLATE utf8mb4_unicode_ci NOT NULL"
)
)
conn.commit()
finally:
conn.execute(text("SET FOREIGN_KEY_CHECKS=1"))
with engine.begin() as conn:
def _split_sql_statements(content: str) -> list[str]: conn.execute(
""" text(
按分号拆分 SQL 语句忽略注释和空行 "ALTER TABLE report_section_references "
简单实现不处理字符串内的分号 "ADD COLUMN template_id VARCHAR(64) NULL"
""" )
# 移除单行注释 )
content = re.sub(r"--[^\n]*", "", content) conn.execute(
# 移除多行注释 text(
content = re.sub(r"/\*.*?\*/", "", content, flags=re.DOTALL) "ALTER TABLE report_section_references "
statements = [ "ADD INDEX ix_report_section_references_template_id (template_id)"
s.strip() )
for s in content.split(";") )
if s.strip() and not s.strip().startswith("--") logger.info("init_database: report_section_references.template_id 字段已补充")
]
return statements
def init_database() -> None: def init_database() -> None:
""" """在远程 MySQL 中创建本模块所需表(若不存在)。"""
执行 init.sql创建表结构并按需执行缺失字段迁移 Base.metadata.create_all(bind=engine, checkfirst=True)
_ensure_reference_template_id_column()
注意init.sql 里使用了 `CREATE TABLE IF NOT EXISTS`因此对已存在但缺列的旧库 logger.info("init_database: report_templates / report_template_sections / report_section_references 已就绪")
需要额外执行对应迁移脚本例如补齐 `kb_documents.factor`
"""
if not INIT_SQL_PATH.exists():
return
sql_text = INIT_SQL_PATH.read_text(encoding="utf-8")
statements = _split_sql_statements(sql_text)
with engine.connect() as conn:
for stmt in statements:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
# 表/索引已存在时忽略Duplicate key name、already exists
err_msg = str(e).lower()
# 历史库可能缺列,导致 CREATE INDEX 报 "Key column ... doesn't exist in table"。
# 这里先跳过,后续 migrate_extraction_results.sql 会补齐列并建索引。
if (
"already exists" in err_msg
or "duplicate" in err_msg
or ("key column" in err_msg and "doesn't exist in table" in err_msg)
or "error 1072" in err_msg
):
conn.rollback()
continue
conn.rollback()
raise
# 仅在字符集未达标时执行 ALTER勿在每次 CREATE TABLE projects 后重复调用)
_normalize_projects_table(conn)
# ------------------------------------------------------------------
# Missing-column migrations (idempotent via "duplicate column" ignore)
# ------------------------------------------------------------------
factor_migrate_path = Path(__file__).resolve().parent / "migrate_kb_documents_factor.sql"
if factor_migrate_path.exists():
factor_sql_text = factor_migrate_path.read_text(encoding="utf-8")
factor_statements = _split_sql_statements(factor_sql_text)
with engine.connect() as conn:
for stmt in factor_statements:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
# MySQL: Error 1060 "Duplicate column name 'factor'"
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# Missing tables/columns migrations (kb_directories + directory_id)
# ------------------------------------------------------------------
kb_dirs_migrate_path = Path(__file__).resolve().parent / "migrate_kb_directories.sql"
if kb_dirs_migrate_path.exists():
kb_dirs_sql_text = kb_dirs_migrate_path.read_text(encoding="utf-8")
kb_dirs_statements = _split_sql_statements(kb_dirs_sql_text)
with engine.connect() as conn:
for stmt in kb_dirs_statements:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
# MySQL 常见“已存在/重复”错误:忽略以保证幂等
if (
"duplicate column" in err_msg
or "error 1060" in err_msg
or "already exists" in err_msg
or "duplicate" in err_msg
):
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# Missing tables/columns migrations (extraction_results legacy schema)
# ------------------------------------------------------------------
extraction_migrate_path = Path(__file__).resolve().parent / "migrate_extraction_results.sql"
if extraction_migrate_path.exists():
extraction_sql_text = extraction_migrate_path.read_text(encoding="utf-8")
extraction_statements = _split_sql_statements(extraction_sql_text)
with engine.connect() as conn:
for stmt in extraction_statements:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if (
"duplicate column" in err_msg
or "error 1060" in err_msg
or "already exists" in err_msg
or "duplicate" in err_msg
or "check that column/key exists" in err_msg
or "error 1072" in err_msg
or "doesn't exist" in err_msg
):
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# extraction_results移除历史 run_id 外键/字段(收敛到 batch_id
# ------------------------------------------------------------------
extraction_drop_run_id_path = (
Path(__file__).resolve().parent / "migrate_extraction_results_drop_run_id.sql"
)
if extraction_drop_run_id_path.exists():
extraction_drop_run_id_sql = extraction_drop_run_id_path.read_text(encoding="utf-8")
extraction_drop_run_id_statements = _split_sql_statements(extraction_drop_run_id_sql)
with engine.connect() as conn:
for stmt in extraction_drop_run_id_statements:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if (
"already exists" in err_msg
or "duplicate" in err_msg
or "doesn't exist" in err_msg
or "check that column/key exists" in err_msg
or "error 1091" in err_msg
):
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# element_conflicts补齐 table_id / cell_id旧库缺列导致 ORM 查询 500
# ------------------------------------------------------------------
ec_migrate_path = Path(__file__).resolve().parent / "migrate_element_conflicts.sql"
if ec_migrate_path.exists():
ec_sql_text = ec_migrate_path.read_text(encoding="utf-8")
ec_statements = _split_sql_statements(ec_sql_text)
with engine.connect() as conn:
for stmt in ec_statements:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if (
"duplicate column" in err_msg
or "error 1060" in err_msg
or "already exists" in err_msg
or "errno 1060" in err_msg
):
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# element_conflicts兼容历史 project_element_id NOT NULL改为 NULL
# ------------------------------------------------------------------
ec_project_element_id_path = (
Path(__file__).resolve().parent / "migrate_element_conflicts_project_element_id_nullable.sql"
)
if ec_project_element_id_path.exists():
ec_peid_sql = ec_project_element_id_path.read_text(encoding="utf-8")
ec_peid_stmts = _split_sql_statements(ec_peid_sql)
with engine.connect() as conn:
for stmt in ec_peid_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception:
conn.rollback()
raise
# ------------------------------------------------------------------
# element_conflicts兼容历史 extraction_result_id NOT NULL改为 NULL
# ------------------------------------------------------------------
ec_extraction_result_id_path = (
Path(__file__).resolve().parent / "migrate_element_conflicts_extraction_result_id_nullable.sql"
)
if ec_extraction_result_id_path.exists():
ec_erid_sql = ec_extraction_result_id_path.read_text(encoding="utf-8")
ec_erid_stmts = _split_sql_statements(ec_erid_sql)
with engine.connect() as conn:
for stmt in ec_erid_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception:
conn.rollback()
raise
# ------------------------------------------------------------------
# extraction_resultsextracted_at / source_line_end
# ------------------------------------------------------------------
ext_time_path = Path(__file__).resolve().parent / "migrate_extraction_results_extracted_at.sql"
if ext_time_path.exists():
ext_sql = ext_time_path.read_text(encoding="utf-8")
ext_stmts = _split_sql_statements(ext_sql)
with engine.connect() as conn:
for stmt in ext_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if (
"duplicate column" in err_msg
or "error 1060" in err_msg
or "already exists" in err_msg
):
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# element_extraction_results要素抽取结果明细表若旧库缺表则补齐
# ------------------------------------------------------------------
el_ext_path = Path(__file__).resolve().parent / "migrate_element_extraction_results.sql"
if el_ext_path.exists():
el_ext_sql = el_ext_path.read_text(encoding="utf-8")
el_ext_stmts = _split_sql_statements(el_ext_sql)
with engine.connect() as conn:
for stmt in el_ext_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if (
"already exists" in err_msg
or "duplicate" in err_msg
):
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# project_departments项目可见部门
# ------------------------------------------------------------------
proj_dept_path = Path(__file__).resolve().parent / "migrate_project_departments.sql"
if proj_dept_path.exists():
proj_dept_sql = proj_dept_path.read_text(encoding="utf-8")
proj_dept_stmts = _split_sql_statements(proj_dept_sql)
with engine.connect() as conn:
for stmt in proj_dept_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "already exists" in err_msg or "duplicate" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# report_template_sections章节输出合同section_output_contract
# ------------------------------------------------------------------
template_section_contract_path = (
Path(__file__).resolve().parent / "migrations" / "add_section_output_contract.sql"
)
if template_section_contract_path.exists():
template_section_contract_sql = template_section_contract_path.read_text(encoding="utf-8")
template_section_contract_stmts = _split_sql_statements(template_section_contract_sql)
with engine.connect() as conn:
for stmt in template_section_contract_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# report_section_references补齐 template_id按模板过滤参考范文
# ------------------------------------------------------------------
ref_template_id_path = (
Path(__file__).resolve().parent / "migrations" / "add_ref_template_id.sql"
)
if ref_template_id_path.exists():
ref_template_id_sql = ref_template_id_path.read_text(encoding="utf-8")
ref_template_id_stmts = _split_sql_statements(ref_template_id_sql)
with engine.connect() as conn:
for stmt in ref_template_id_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if (
"duplicate column" in err_msg
or "error 1060" in err_msg
or "duplicate key name" in err_msg
or "error 1061" in err_msg
or "already exists" in err_msg
):
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# users补齐 password_hash登录注册
# ------------------------------------------------------------------
users_pwd_path = Path(__file__).resolve().parent / "migrate_users_password_hash.sql"
if users_pwd_path.exists():
users_pwd_sql = users_pwd_path.read_text(encoding="utf-8")
users_pwd_stmts = _split_sql_statements(users_pwd_sql)
with engine.connect() as conn:
for stmt in users_pwd_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# departments补齐 description部门描述
# ------------------------------------------------------------------
dept_desc_path = Path(__file__).resolve().parent / "migrate_departments_description.sql"
if dept_desc_path.exists():
dept_desc_sql = dept_desc_path.read_text(encoding="utf-8")
dept_desc_stmts = _split_sql_statements(dept_desc_sql)
with engine.connect() as conn:
for stmt in dept_desc_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# projects用户删除的标准模版表不再被「同步模版」回补
# ------------------------------------------------------------------
proj_sup_path = Path(__file__).resolve().parent / "migrate_projects_sync_suppressed_tables.sql"
if proj_sup_path.exists():
proj_sup_sql = proj_sup_path.read_text(encoding="utf-8")
proj_sup_stmts = _split_sql_statements(proj_sup_sql)
with engine.connect() as conn:
for stmt in proj_sup_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# element_tables用户删行后不再被「同步模版」回补
# ------------------------------------------------------------------
et_sup_path = Path(__file__).resolve().parent / "migrate_element_tables_sync_suppressed.sql"
if et_sup_path.exists():
et_sup_sql = et_sup_path.read_text(encoding="utf-8")
et_sup_stmts = _split_sql_statements(et_sup_sql)
with engine.connect() as conn:
for stmt in et_sup_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# element_tables自定义行顺序加行插在选中行下刷新后仍保持
# ------------------------------------------------------------------
et_row_order_path = Path(__file__).resolve().parent / "migrate_element_tables_custom_row_order.sql"
if et_row_order_path.exists():
et_row_order_sql = et_row_order_path.read_text(encoding="utf-8")
et_row_order_stmts = _split_sql_statements(et_row_order_sql)
with engine.connect() as conn:
for stmt in et_row_order_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# kb_documentsstatus 语义 v20/2/3/4仅旧库仍有 status=1 且无 status=4时执行
# ------------------------------------------------------------------
status_v2_path = Path(__file__).resolve().parent / "migrate_kb_doc_status_v2.sql"
if status_v2_path.exists():
with engine.connect() as conn:
try:
probe = conn.execute(
text(
"""
SELECT
SUM(CASE WHEN status = 1 THEN 1 ELSE 0 END) AS s1,
SUM(CASE WHEN status = 4 THEN 1 ELSE 0 END) AS s4
FROM kb_documents
"""
)
).fetchone()
s1 = int((probe[0] if probe else 0) or 0)
s4 = int((probe[1] if probe else 0) or 0)
if s1 > 0 and s4 == 0:
status_v2_sql = status_v2_path.read_text(encoding="utf-8")
for stmt in _split_sql_statements(status_v2_sql):
stmt = stmt.strip()
if not stmt:
continue
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "doesn't exist" in err_msg and "kb_documents" in err_msg:
conn.rollback()
else:
conn.rollback()
raise
# ------------------------------------------------------------------
# kb_documentsstorage_rel_path + error_message
# ------------------------------------------------------------------
kb_storage_path = Path(__file__).resolve().parent / "migrate_kb_doc_storage_path.sql"
if kb_storage_path.exists():
kb_storage_sql = kb_storage_path.read_text(encoding="utf-8")
kb_storage_stmts = _split_sql_statements(kb_storage_sql)
with engine.connect() as conn:
for stmt in kb_storage_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# kb_documentscategory (文件分类)
# ------------------------------------------------------------------
category_migrate_path = Path(__file__).resolve().parent / "migrate_kb_documents_category.sql"
if category_migrate_path.exists():
category_sql_text = category_migrate_path.read_text(encoding="utf-8")
category_statements = _split_sql_statements(category_sql_text)
with engine.connect() as conn:
for stmt in category_statements:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# kb_documents旧版分类 → 资料清单六大类
# ------------------------------------------------------------------
category_checklist_path = Path(__file__).resolve().parent / "migrate_kb_category_checklist.sql"
if category_checklist_path.exists():
checklist_sql_text = category_checklist_path.read_text(encoding="utf-8")
checklist_statements = _split_sql_statements(checklist_sql_text)
with engine.connect() as conn:
for stmt in checklist_statements:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception:
conn.rollback()
# ------------------------------------------------------------------
# kb_documentsupload_filename上传/解压原始文件名)
# ------------------------------------------------------------------
upload_fn_path = Path(__file__).resolve().parent / "migrate_kb_documents_upload_filename.sql"
if upload_fn_path.exists():
upload_fn_sql = upload_fn_path.read_text(encoding="utf-8")
upload_fn_stmts = _split_sql_statements(upload_fn_sql)
with engine.connect() as conn:
for stmt in upload_fn_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise
# ------------------------------------------------------------------
# element_cellssource_type文档抽取 / 手工输入)
# ------------------------------------------------------------------
ec_source_type_path = Path(__file__).resolve().parent / "migrate_element_cells_source_type.sql"
if ec_source_type_path.exists():
ec_source_sql = ec_source_type_path.read_text(encoding="utf-8")
ec_source_stmts = _split_sql_statements(ec_source_sql)
with engine.connect() as conn:
for stmt in ec_source_stmts:
stmt = stmt.strip()
if not stmt:
continue
try:
conn.execute(text(stmt))
conn.commit()
except Exception as e:
err_msg = str(e).lower()
if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg:
conn.rollback()
continue
conn.rollback()
raise

View File

@ -1,305 +1,24 @@
""" """
database/models.py database/models.py
SQLAlchemy ORM 模型 db.md / init.sql 对应 ORM 模型与远程 MySQLeval_report 现有表结构一致
- report_templates 模板
- report_template_sections 模板章节目录 + 声明
- report_section_references 章节参考范文章节内容入库目标
""" """
from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, JSON, String, Text, UniqueConstraint from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass pass
class Project(Base):
"""项目表(统一:知识库 + 撰写)"""
__tablename__ = "projects"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
uuid: Mapped[str] = mapped_column(
String(32),
unique=True,
nullable=False,
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
doc_count: Mapped[int] = mapped_column(Integer, default=0)
eval_reports_count: Mapped[int] = mapped_column(Integer, default=0)
total_size: Mapped[str] = mapped_column(String(32), default="0 B")
tags: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String(16), default="active")
color: Mapped[str] = mapped_column(String(16), default="#3b82f6")
sync_suppressed_table_names: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
kb_documents: Mapped[list["KbDocument"]] = relationship(
"KbDocument", back_populates="project", cascade="all, delete-orphan"
)
kb_directories: Mapped[list["KbDirectory"]] = relationship(
"KbDirectory", back_populates="project", cascade="all, delete-orphan"
)
write_documents: Mapped[list["WriteDocumentModel"]] = relationship(
"WriteDocumentModel", back_populates="project", cascade="all, delete-orphan"
)
class WriteDocumentModel(Base):
"""撰写文档表后评价报告。project_id 关联 projects.uuid"""
__tablename__ = "write_documents"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
title: Mapped[str] = mapped_column(String(255), nullable=False)
content: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
word_count: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
status: Mapped[str] = mapped_column(String(16), default="draft")
sort_order: Mapped[int] = mapped_column(Integer, default=0)
project: Mapped["Project"] = relationship("Project", back_populates="write_documents")
doc_versions: Mapped[list["DocumentVersion"]] = relationship(
"DocumentVersion", back_populates="document", cascade="all, delete-orphan"
)
class DocumentVersion(Base):
"""撰写文档版本表(对应 doc_versions"""
__tablename__ = "doc_versions"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
document_id: Mapped[str] = mapped_column(
ForeignKey("write_documents.id", ondelete="CASCADE"), nullable=False
)
version: Mapped[str] = mapped_column(String(32), nullable=False)
content: Mapped[str] = mapped_column(Text, nullable=False)
citation_payload: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
saved_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
author: Mapped[str] = mapped_column(String(64), nullable=False)
note: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
document: Mapped["WriteDocumentModel"] = relationship("WriteDocumentModel", back_populates="doc_versions")
class KbDocument(Base):
"""知识库文档表。project_id 关联 projects.uuid。status: 0=失败 2=排队中 3=处理中 4=可用"""
__tablename__ = "kb_documents"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
directory_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("kb_directories.id", ondelete="SET NULL"), nullable=True
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
upload_filename: Mapped[Optional[str]] = mapped_column(
String(255), nullable=True
) # 上传/解压时的原始文件名(含扩展名),与智能展示名 name 区分
size: Mapped[str] = mapped_column(String(32), nullable=False)
file_path: Mapped[Optional[str]] = mapped_column(String(512), nullable=True) # 仅目录路径,不含文件名
storage_rel_path: Mapped[Optional[str]] = mapped_column(
String(512), nullable=True
) # 项目内完整相对路径(含文件名),用于精确定位磁盘文件
word_count: Mapped[int] = mapped_column(Integer, default=0)
uploaded_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
status: Mapped[int] = mapped_column(Integer, default=2) # 0=失败 2=排队中 3=处理中 4=可用
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
category: Mapped[Optional[str]] = mapped_column(String(32), nullable=True, default=None)
project: Mapped["Project"] = relationship("Project", back_populates="kb_documents")
directory: Mapped[Optional["KbDirectory"]] = relationship("KbDirectory", back_populates="documents")
class KbDirectory(Base):
"""知识库目录表。project_id 关联 projects.uuidparent_id 形成目录树。"""
__tablename__ = "kb_directories"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
parent_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("kb_directories.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str] = mapped_column(String(255), nullable=False)
full_path: Mapped[str] = mapped_column(String(1024), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
project: Mapped["Project"] = relationship("Project", back_populates="kb_directories")
documents: Mapped[list["KbDocument"]] = relationship("KbDocument", back_populates="directory")
class Task(Base):
"""独立后台任务表pdf2md 转换和 element-agent 要素抽取。"""
__tablename__ = "tasks"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project: Mapped[str] = mapped_column(String(64), nullable=False)
task_type: Mapped[int] = mapped_column(Integer, nullable=False)
file_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
file_path: Mapped[Optional[str]] = mapped_column(String(1024), nullable=True)
status: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
payload_json: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
result_path: Mapped[Optional[str]] = mapped_column(String(1024), nullable=True)
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
add_time: Mapped[datetime] = mapped_column(DateTime, nullable=False)
finish_time: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
class ElementTable(Base):
__tablename__ = "element_tables"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
table_type: Mapped[str] = mapped_column(String(32), nullable=False) # global/time
table_name: Mapped[str] = mapped_column(String(255), nullable=False)
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
is_time_dimension: Mapped[bool] = mapped_column(Boolean, default=False)
sort_order: Mapped[int] = mapped_column(Integer, default=0)
# JSON 数组字符串row_key 列表sync 模版时跳过为这些行补格子,避免用户删行后一同步又出现
sync_suppressed_row_keys: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
# JSON 数组:界面行键展示顺序(含用户加行)
custom_row_order: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class ElementCell(Base):
__tablename__ = "element_cells"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
table_id: Mapped[str] = mapped_column(ForeignKey("element_tables.id", ondelete="CASCADE"), nullable=False)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
row_key: Mapped[str] = mapped_column(String(255), nullable=False)
col_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
source_document_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("kb_documents.id", ondelete="SET NULL"), nullable=True
)
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
source_quote: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
extraction_batch_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
extraction_model: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
source_type: Mapped[Optional[str]] = mapped_column(String(16), nullable=True) # extract | manual
conflict_status: Mapped[str] = mapped_column(String(16), default="none")
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class ExtractionResult(Base):
__tablename__ = "extraction_results"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
document_id: Mapped[str] = mapped_column(ForeignKey("kb_documents.id", ondelete="CASCADE"), nullable=False)
batch_id: Mapped[str] = mapped_column(String(64), nullable=False)
result_type: Mapped[str] = mapped_column(String(16), nullable=False) # table/element
table_type: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
table_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
item_key: Mapped[str] = mapped_column(String(255), nullable=False)
item_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
raw_payload: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
extracted_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True) # 抽取业务时间(旧库迁移前可为空)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class ElementExtractionResult(Base):
"""
要素抽取结果明细表面向细则章节/小节提示词 -> 项目材料抽取
字段对齐用户侧语义
- 表类型 -> table_type
- 年份 -> year
- 表名称 -> table_name
- 时间 -> extracted_at
- -> item_key
- -> item_value
- 来源文档ID -> source_document_id
- 来源行数 -> source_line_no / source_line_end
"""
__tablename__ = "element_extraction_results"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
table_type: Mapped[str] = mapped_column(String(32), nullable=False)
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
table_name: Mapped[str] = mapped_column(String(255), nullable=False)
extracted_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
item_key: Mapped[str] = mapped_column(String(255), nullable=False)
item_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
source_document_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("kb_documents.id", ondelete="SET NULL"), nullable=True
)
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class ElementConflict(Base):
__tablename__ = "element_conflicts"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
table_id: Mapped[Optional[str]] = mapped_column(ForeignKey("element_tables.id", ondelete="SET NULL"), nullable=True)
cell_id: Mapped[Optional[str]] = mapped_column(ForeignKey("element_cells.id", ondelete="SET NULL"), nullable=True)
item_key: Mapped[str] = mapped_column(String(255), nullable=False)
old_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
new_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
selected_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
source_document_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("kb_documents.id", ondelete="SET NULL"), nullable=True
)
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
status: Mapped[str] = mapped_column(String(16), default="pending")
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class DocumentMarkdown(Base):
__tablename__ = "document_markdowns"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
document_id: Mapped[str] = mapped_column(ForeignKey("kb_documents.id", ondelete="CASCADE"), nullable=False)
extracted_filename: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
markdown_content: Mapped[str] = mapped_column(Text, nullable=False)
content_hash: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class DocumentChunk(Base):
__tablename__ = "document_chunks"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
document_id: Mapped[str] = mapped_column(ForeignKey("kb_documents.id", ondelete="CASCADE"), nullable=False)
markdown_id: Mapped[Optional[str]] = mapped_column(ForeignKey("document_markdowns.id", ondelete="CASCADE"), nullable=True)
heading: Mapped[Optional[str]] = mapped_column(String(512), nullable=True)
chunk_text: Mapped[str] = mapped_column(Text, nullable=False)
chunk_index: Mapped[int] = mapped_column(Integer, default=0)
source_line_start: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
vector_id: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class ReportTemplate(Base): class ReportTemplate(Base):
__tablename__ = "report_templates" __tablename__ = "report_templates"
@ -316,9 +35,12 @@ class ReportTemplateSection(Base):
__tablename__ = "report_template_sections" __tablename__ = "report_template_sections"
id: Mapped[str] = mapped_column(String(64), primary_key=True) id: Mapped[str] = mapped_column(String(64), primary_key=True)
template_id: Mapped[str] = mapped_column(ForeignKey("report_templates.id", ondelete="CASCADE"), nullable=False) template_id: Mapped[str] = mapped_column(
ForeignKey("report_templates.id", ondelete="CASCADE"), nullable=False
)
section_key: Mapped[str] = mapped_column(String(64), nullable=False) section_key: Mapped[str] = mapped_column(String(64), nullable=False)
section_title: Mapped[str] = mapped_column(String(255), nullable=False) section_title: Mapped[str] = mapped_column(String(255), nullable=False)
# 本模块语义section_prompt 即为该目录生成的"声明"
section_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True) section_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
section_output_contract: Mapped[Optional[str]] = mapped_column(Text, nullable=True) section_output_contract: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
section_order: Mapped[int] = mapped_column(Integer, default=0) section_order: Mapped[int] = mapped_column(Integer, default=0)
@ -327,54 +49,15 @@ class ReportTemplateSection(Base):
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class ReportGenerationJob(Base):
__tablename__ = "report_generation_jobs"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
template_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("report_templates.id", ondelete="SET NULL"), nullable=True
)
status: Mapped[str] = mapped_column(String(16), default="pending") # pending/running/completed/failed
progress: Mapped[int] = mapped_column(Integer, default=0)
current_section_key: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
requested_by: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
options: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
snapshot: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
class ReportGenerationChapter(Base):
__tablename__ = "report_generation_chapters"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
job_id: Mapped[str] = mapped_column(
ForeignKey("report_generation_jobs.id", ondelete="CASCADE"), nullable=False
)
section_key: Mapped[str] = mapped_column(String(64), nullable=False)
section_title: Mapped[str] = mapped_column(String(255), nullable=False)
section_order: Mapped[int] = mapped_column(Integer, default=0)
status: Mapped[str] = mapped_column(String(16), default="pending") # pending/running/completed/failed
content: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
prompt_text: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
evidence_payload: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
validation_payload: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True)
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True)
class ReportSectionReference(Base): class ReportSectionReference(Base):
"""章节参考范文(独立于模板配置,用于报告生成时拼入 prompt""" """章节参考范文(章节内容入库目标,格式与远程 MySQL 现有表一致)。"""
__tablename__ = "report_section_references" __tablename__ = "report_section_references"
id: Mapped[str] = mapped_column(String(64), primary_key=True) id: Mapped[str] = mapped_column(String(64), primary_key=True)
# 关联模板(与 report_template_sections.template_id 一致);历史数据可能为空
template_id: Mapped[Optional[str]] = mapped_column( template_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("report_templates.id", ondelete="CASCADE"), nullable=True ForeignKey("report_templates.id", ondelete="CASCADE"), nullable=True, index=True
) )
source_file: Mapped[str] = mapped_column(String(255), nullable=False) source_file: Mapped[str] = mapped_column(String(255), nullable=False)
section_key: Mapped[str] = mapped_column(String(64), nullable=False) section_key: Mapped[str] = mapped_column(String(64), nullable=False)
@ -383,121 +66,3 @@ class ReportSectionReference(Base):
content: Mapped[str] = mapped_column(Text, nullable=False) content: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class Department(Base):
__tablename__ = "department"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
parent_id: Mapped[Optional[str]] = mapped_column(ForeignKey("departments.id", ondelete="SET NULL"), nullable=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
username: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
department_id: Mapped[Optional[str]] = mapped_column(ForeignKey("departments.id", ondelete="SET NULL"), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class Role(Base):
__tablename__ = "roles"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
name: Mapped[str] = mapped_column(String(64), nullable=False, unique=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class Permission(Base):
__tablename__ = "permissions"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
perm_key: Mapped[str] = mapped_column(String(128), nullable=False, unique=True)
perm_type: Mapped[str] = mapped_column(String(32), nullable=False) # menu/project
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class RolePermission(Base):
__tablename__ = "role_permissions"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
role_id: Mapped[str] = mapped_column(ForeignKey("roles.id", ondelete="CASCADE"), nullable=False)
permission_id: Mapped[str] = mapped_column(ForeignKey("permissions.id", ondelete="CASCADE"), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class UserRole(Base):
__tablename__ = "user_roles"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
user_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
role_id: Mapped[str] = mapped_column(ForeignKey("roles.id", ondelete="CASCADE"), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class ProjectMember(Base):
__tablename__ = "project_members"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
user_id: Mapped[str] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
role: Mapped[str] = mapped_column(String(32), default="editor")
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class ProjectDepartment(Base):
"""项目可见部门:绑定后,仅这些部门下的用户可访问(另有管理员与 project_members 例外)。"""
__tablename__ = "project_departments"
__table_args__ = (UniqueConstraint("project_id", "department_id", name="uq_project_department"),)
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
department_id: Mapped[str] = mapped_column(ForeignKey("departments.id", ondelete="CASCADE"), nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)
class FillRecord(Base):
"""回填记录:每次要素回填均留痕,支持证据追溯。"""
__tablename__ = "fill_records"
id: Mapped[str] = mapped_column(String(64), primary_key=True)
project_id: Mapped[str] = mapped_column(ForeignKey("projects.uuid", ondelete="CASCADE"), nullable=False)
cell_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("element_cells.id", ondelete="SET NULL"), nullable=True
)
table_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("element_tables.id", ondelete="SET NULL"), nullable=True
)
row_key: Mapped[str] = mapped_column(String(255), nullable=False)
col_key: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
year: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
filled_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
previous_value: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
source_document_id: Mapped[Optional[str]] = mapped_column(
ForeignKey("kb_documents.id", ondelete="SET NULL"), nullable=True
)
source_document_name: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
source_line_no: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
source_line_end: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
source_quote: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
confidence: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
extraction_batch_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
extraction_model: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
fill_type: Mapped[str] = mapped_column(String(16), nullable=False, default="auto")
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)

View File

@ -1,3 +1,5 @@
from .logger import configure_logging, get_logger """日志包:统一日志配置。"""
from log.logger import configure_logging, get_logger
__all__ = ["configure_logging", "get_logger"] __all__ = ["configure_logging", "get_logger"]

View File

@ -1,38 +1,40 @@
"""
log/logger.py
统一日志配置
- 控制台输出强制 UTF-8修复 Windows 控制台中文乱码
- logs/app.log 全量日志按大小轮转
- logs/error.log WARNING 及以上
- logs/upload.log 上传/解析/入库链路routers.templateservices.*
- 接管 uvicorn access/error 日志统一落盘
幂等重复调用只配置一次
"""
from __future__ import annotations from __future__ import annotations
import logging import logging
import sys
from logging.handlers import RotatingFileHandler from logging.handlers import RotatingFileHandler
from pathlib import Path from pathlib import Path
_CONFIGURED = False _CONFIGURED = False
_FILE_PROCESSING_PREFIXES = (
"worker.document_processing", _FORMAT = "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s"
"services.kb_service", _DATEFMT = "%Y-%m-%d %H:%M:%S"
"services.es_docs",
"services.element_llm_extract_service", # 上传/解析/入库链路相关的 logger 前缀(额外汇总到 upload.log
"routers.extract", _UPLOAD_PREFIXES = (
"function.documents", "routers.template",
"function.vector_store", "services.file_parse_client",
"repo.kb_documents", "services.section_extractor",
"routers.reference", "services.declaration_service",
"services.doc_convert_service",
"services.reference_service",
)
_DOCUMENT_GENERATION_PREFIXES = (
"services.write_service",
"services.report_generation_service",
"services.markdown_stream_service",
"services.llm_client", "services.llm_client",
"services.llm_runner",
"services.report_prompt_service",
"services.report_runtime_store",
)
# 生成全过程追踪:完整记录输入 prompt / 调用模型 / 模型输出
_GENERATION_TRACE_PREFIXES = (
"generation.trace",
) )
# 交由 root 统一处理的第三方/框架 logger
_DELEGATED_LOGGERS = ("uvicorn", "uvicorn.error", "uvicorn.access")
class _PrefixFilter(logging.Filter): class _PrefixFilter(logging.Filter):
def __init__(self, prefixes: tuple[str, ...]) -> None: def __init__(self, prefixes: tuple[str, ...]) -> None:
@ -41,144 +43,99 @@ class _PrefixFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool: def filter(self, record: logging.LogRecord) -> bool:
name = str(record.name or "") name = str(record.name or "")
return any(name == prefix or name.startswith(prefix + ".") for prefix in self.prefixes) return any(name == p or name.startswith(p + ".") for p in self.prefixes)
class _OtherFilter(logging.Filter): def _force_utf8_stream(stream):
def filter(self, record: logging.LogRecord) -> bool: """让控制台以 UTF-8 输出,避免 Windows GBK 控制台中文乱码。"""
name = str(record.name or "") reconfigure = getattr(stream, "reconfigure", None)
if any(name == prefix or name.startswith(prefix + ".") for prefix in _FILE_PROCESSING_PREFIXES): if callable(reconfigure):
return False try:
if any(name == prefix or name.startswith(prefix + ".") for prefix in _DOCUMENT_GENERATION_PREFIXES): reconfigure(encoding="utf-8", errors="replace")
return False except (ValueError, OSError):
if any(name == prefix or name.startswith(prefix + ".") for prefix in _GENERATION_TRACE_PREFIXES): pass
return False return stream
return True
def configure_logging( def configure_logging(
*, *,
log_dir: str | Path = "logs", log_dir: str | Path | None = None,
level: int = logging.INFO, level: str | int | None = None,
to_console: bool | None = None,
max_bytes: int | None = None,
backup_count: int | None = None,
) -> Path: ) -> Path:
"""配置全局日志。返回 app.log 路径。"""
global _CONFIGURED global _CONFIGURED
# 延迟导入,避免与 config 形成循环依赖问题
from config import settings
log_dir = log_dir if log_dir is not None else settings.LOG_DIR
level = level if level is not None else settings.LOG_LEVEL
to_console = to_console if to_console is not None else settings.LOG_TO_CONSOLE
max_bytes = max_bytes if max_bytes is not None else settings.LOG_MAX_BYTES
backup_count = backup_count if backup_count is not None else settings.LOG_BACKUP_COUNT
if isinstance(level, str):
level = getattr(logging, level.strip().upper(), logging.INFO)
target_dir = Path(log_dir).resolve() target_dir = Path(log_dir).resolve()
target_dir.mkdir(parents=True, exist_ok=True) target_dir.mkdir(parents=True, exist_ok=True)
other_log_path = target_dir / "other.log" app_log_path = target_dir / "app.log"
if _CONFIGURED: if _CONFIGURED:
return other_log_path return app_log_path
formatter = logging.Formatter( formatter = logging.Formatter(_FORMAT, datefmt=_DATEFMT)
"%(asctime)s | %(levelname)s | %(name)s | %(message)s"
)
root_logger = logging.getLogger() def _rotating(name: str, *, backups: int | None = None) -> RotatingFileHandler:
root_logger.setLevel(level) h = RotatingFileHandler(
target_dir / name,
maxBytes=max_bytes,
backupCount=backups if backups is not None else backup_count,
encoding="utf-8",
)
h.setFormatter(formatter)
return h
file_processing_handler = RotatingFileHandler( # 全量日志
target_dir / "file_processing.log", app_handler = _rotating("app.log")
maxBytes=10 * 1024 * 1024, app_handler.setLevel(level)
backupCount=5,
encoding="utf-8",
)
file_processing_handler.setLevel(level)
file_processing_handler.setFormatter(formatter)
file_processing_handler.addFilter(_PrefixFilter(_FILE_PROCESSING_PREFIXES))
document_generation_handler = RotatingFileHandler( # 错误日志WARNING+
target_dir / "document_generation.log", error_handler = _rotating("error.log")
maxBytes=10 * 1024 * 1024, error_handler.setLevel(logging.WARNING)
backupCount=5,
encoding="utf-8",
)
document_generation_handler.setLevel(level)
document_generation_handler.setFormatter(formatter)
document_generation_handler.addFilter(_PrefixFilter(_DOCUMENT_GENERATION_PREFIXES))
other_handler = RotatingFileHandler( # 上传/解析链路日志
other_log_path, upload_handler = _rotating("upload.log", backups=max(backup_count, 10))
maxBytes=10 * 1024 * 1024, upload_handler.setLevel(level)
backupCount=5, upload_handler.addFilter(_PrefixFilter(_UPLOAD_PREFIXES))
encoding="utf-8",
)
other_handler.setLevel(level)
other_handler.setFormatter(formatter)
other_handler.addFilter(_OtherFilter())
# ── 要素抽取独立日志 ───────────────────────────────────────────── handlers: list[logging.Handler] = [app_handler, error_handler, upload_handler]
element_extract_handler = RotatingFileHandler(
target_dir / "element_extract.log",
maxBytes=10 * 1024 * 1024,
backupCount=10,
encoding="utf-8",
)
element_extract_handler.setLevel(level)
element_extract_handler.setFormatter(formatter)
element_extract_handler.addFilter(_PrefixFilter(("services.element_llm_extract_service", "routers.extract")))
# ── 文件上传/解析独立日志 ───────────────────────────────────────── if to_console:
file_upload_handler = RotatingFileHandler( console_handler = logging.StreamHandler(_force_utf8_stream(sys.stdout))
target_dir / "file_upload.log", console_handler.setLevel(level)
maxBytes=10 * 1024 * 1024, console_handler.setFormatter(formatter)
backupCount=10, handlers.append(console_handler)
encoding="utf-8",
)
file_upload_handler.setLevel(level)
file_upload_handler.setFormatter(formatter)
file_upload_handler.addFilter(_PrefixFilter(("routers.reference", "routers.template", "services.doc_convert_service", "services.reference_service", "services.kb_service", "routers.kb")))
# ── 报告生成独立日志 ────────────────────────────────────────────── root = logging.getLogger()
report_generation_handler = RotatingFileHandler( root.setLevel(level)
target_dir / "report_generation.log", root.handlers.clear()
maxBytes=10 * 1024 * 1024, for h in handlers:
backupCount=10, root.addHandler(h)
encoding="utf-8",
)
report_generation_handler.setLevel(level)
report_generation_handler.setFormatter(formatter)
report_generation_handler.addFilter(_PrefixFilter(("services.report_generation_service", "services.report_prompt_service", "services.report_runtime_store", "services.markdown_stream_service")))
# ── LLM 调用独立日志 ────────────────────────────────────────────── # 让 uvicorn 的日志走 root 统一落盘
llm_handler = RotatingFileHandler( for name in _DELEGATED_LOGGERS:
target_dir / "llm.log", lg = logging.getLogger(name)
maxBytes=10 * 1024 * 1024, lg.handlers.clear()
backupCount=10, lg.propagate = True
encoding="utf-8", lg.setLevel(level)
)
llm_handler.setLevel(level)
llm_handler.setFormatter(formatter)
llm_handler.addFilter(_PrefixFilter(("services.llm_client", "services.llm_runner")))
# ── 生成全过程追踪日志(输入 prompt / 模型 / 输出,单条可能较大)────────
generation_trace_handler = RotatingFileHandler(
target_dir / "generation_trace.log",
maxBytes=50 * 1024 * 1024,
backupCount=10,
encoding="utf-8",
)
generation_trace_handler.setLevel(level)
generation_trace_handler.setFormatter(formatter)
generation_trace_handler.addFilter(_PrefixFilter(_GENERATION_TRACE_PREFIXES))
stream_handler = logging.StreamHandler()
stream_handler.setLevel(level)
stream_handler.setFormatter(formatter)
root_logger.handlers.clear()
root_logger.addHandler(file_processing_handler)
root_logger.addHandler(document_generation_handler)
root_logger.addHandler(other_handler)
root_logger.addHandler(element_extract_handler)
root_logger.addHandler(file_upload_handler)
root_logger.addHandler(report_generation_handler)
root_logger.addHandler(llm_handler)
root_logger.addHandler(generation_trace_handler)
root_logger.addHandler(stream_handler)
_CONFIGURED = True _CONFIGURED = True
return other_log_path logging.getLogger(__name__).info("日志系统已初始化 | dir=%s | level=%s", target_dir, logging.getLevelName(level))
return app_log_path
def get_logger(name: str) -> logging.Logger: def get_logger(name: str) -> logging.Logger:

79
main.py
View File

@ -1,43 +1,50 @@
""" """
main.py main.py
报告生成独立服务 FastAPI 入口 报告模板管理模块 FastAPI 应用入口
启动方式 启动
uvicorn main:app --reload uvicorn main:app --host 0.0.0.0 --port 8100
python main.py
python main.py
""" """
import logging from __future__ import annotations
import time
import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import uvicorn from fastapi import FastAPI, Request
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from config import settings from config import settings
from database import engine, init_database from database import init_database
from log import configure_logging from log import configure_logging, get_logger
from routers import report from routers import template
# 在创建应用前完成日志配置
configure_logging() configure_logging()
_log = logging.getLogger(__name__) logger = get_logger("app")
access_logger = get_logger("app.access")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(_app: FastAPI):
"""应用启动与关闭时执行。""" logger.info("应用启动 | %s v%s", settings.APP_TITLE, settings.APP_VERSION)
init_database() if settings.DB_AUTO_CREATE_TABLES:
try:
init_database()
except Exception as e: # noqa: BLE001
logger.warning("启动建表失败(不影响已存在表的使用): %s", e)
yield yield
engine.dispose() logger.info("应用关闭")
app = FastAPI( app = FastAPI(
lifespan=lifespan,
title=settings.APP_TITLE, title=settings.APP_TITLE,
version=settings.APP_VERSION, version=settings.APP_VERSION,
description=settings.APP_DESCRIPTION, description=settings.APP_DESCRIPTION,
docs_url="/docs", lifespan=lifespan,
redoc_url="/redoc",
) )
app.add_middleware( app.add_middleware(
@ -48,19 +55,47 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(report.router, prefix="/api/v1")
@app.middleware("http")
async def log_requests(request: Request, call_next):
if not settings.LOG_HTTP_ACCESS:
return await call_next(request)
req_id = uuid.uuid4().hex[:8]
start = time.perf_counter()
client = request.client.host if request.client else "-"
access_logger.info("→ [%s] %s %s | client=%s", req_id, request.method, request.url.path, client)
try:
response = await call_next(request)
except Exception:
cost = (time.perf_counter() - start) * 1000
access_logger.exception("✗ [%s] %s %s | %.1fms | 未处理异常", req_id, request.method, request.url.path, cost)
raise
cost = (time.perf_counter() - start) * 1000
access_logger.info(
"← [%s] %s %s | %s | %.1fms",
req_id, request.method, request.url.path, response.status_code, cost,
)
response.headers["X-Request-ID"] = req_id
return response
@app.get("/health", tags=["系统"], summary="健康检查") app.include_router(template.router)
def health_check():
"""确认服务存活,返回版本信息。"""
@app.get("/health", tags=["健康检查"])
def health() -> dict:
return {"status": "ok", "version": settings.APP_VERSION} return {"status": "ok", "version": settings.APP_VERSION}
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn
# log_config=None沿用本模块 configure_logging() 的配置,避免被 uvicorn 覆盖
uvicorn.run( uvicorn.run(
"main:app", "main:app",
host=settings.HOST, host=settings.HOST,
port=settings.PORT, port=settings.PORT,
reload=settings.RELOAD, reload=settings.RELOAD,
log_config=None,
) )

View File

@ -0,0 +1,150 @@
"""Default section prompt and example variables used by report templates."""
DEFAULT_SECTION_PROMPT = "按《炼油化工建设项目后评价报告编制细则(修订)》对应章节要求撰写,缺失信息标注“待补充”,禁止编造。"
SECTION_PROMPT_RULES: list[tuple[str, str]] = [
(
"1.1 项目基本情况",
"严格按以下字段顺序输出:项目名称、建设单位、建设地点、建设类型、起止时间、建设内容、建设投资、占地面积。"
"所有事实均须来自证据材料,缺失写“待补充”,禁止编造,禁止复刻示例原文。",
),
(
"1.2 项目决策要点",
"按“项目背景123+ 预期目标(规模/质量/效益)”撰写。"
"证据依据用于内部校验,不在报告正文显示“【证据依据:...】”标记。"
"背景每条先写 24 句书面语再视需要附表;表格须用 Markdown禁止粘贴未对齐的原始文本表。"
"预期目标须从证据归纳:证据中已有产能、产量(万吨/年、辛烷值、国VI、收入、利润等时不得三条目标全写待补充。"
"改扩建项目应补充原装置问题与改造动因。",
),
(
"1.3 项目实施情况",
"说明实际建成内容及与批复方案的差异,给出关键里程碑时间线(立项、批复、开工、中交、投产等)。",
),
(
"1.4 项目运行情况",
"说明投产以来装置运行负荷、产量及主要财务运行情况(营业收入、利润等),并与预期目标对照。",
),
(
"2.1.1 资源与原料评价",
"必须与《模版.doc》中本节结构一致先简述可研原料来源与实际一致性"
"再给出「原料数量及组成对比表」(列含:序号、原料名称、规格、可研/初设/实际各自的「数量(万吨)」「占比(%)」、备注,须有合计行);"
"再给出运行负荷与加工量等对比叙述;"
"再给出「原料性质对比表(醚后碳四)」(列:序号、名称、可研报告、初步设计、实际生产、备注;行至少含密度、硫含量、氮含量等,可按项目增删);"
"最后给出组成/性质变化分析及后评价判断。"
"表格一律使用 Markdown 表头+分隔行,禁止粘贴未对齐的纯文本表。"
"本节第一张主表表题须固定为「表1 原料数量及组成对比表」(与章节输出合同及要素管理默认表题一致,与正文节内表序一致);"
"禁止用安评/专篇中「表2.6-1 原料选择加氢工艺技术对比」等同号异题表替代上述两张模版主表;"
"本节只保留模版主表,不输出附录与“非模版主表”字样。"
"数据来自证据包,缺失填待补充,禁止编造。",
),
(
"2.1.2.1 产品方案评价",
"按“事实依据—评价判断—问题与建议”组织内容,但不要在正文中显示"
"“【事实依据】”“【评价判断】”“【问题与建议】”这三个标题标签;"
"以自然段或编号表达即可。缺失信息写“待补充”,禁止编造。",
),
("2", "对照前期工作细则,评价可研、前评估、初设及决策程序的合规性、完整性与合理性。"),
("3", "对照建设实施细则评价建设管理、招投标、设计、采购、施工、监理、质量、HSE与竣工验收。"),
("4", "对照生产运行细则,评价生产准备、联合试运、达标情况、工艺技术、设备与辅助系统运行效果。"),
("5", "对照投资与经济效益细则,评价投资控制、资金到位、经营效益、后评价测算及不确定性分析。"),
("6", "对照影响与持续性细则,评价环境、安全、科技、社会影响及资源、产品、技术经济竞争力持续性。"),
("7", "给出综合评价结论、成功度评价、主要经验、问题与建议,结论须与前文证据保持一致。"),
]
SECTION_EXAMPLE_RULES: list[tuple[str, str]] = [
(
"1",
"项目名称宁夏石化分公司16万吨/年烷基化装置建设项目独立后评价。"
"建设内容包括16万吨/年烷基化单元、1.5万吨/年废酸再生单元及配套公用工程。"
"批复可研估算22812万元批复初设概算25079万元竣工决算32486万元。"
"烷基化单元2018年11月投产废酸再生单元2020年11月投运。\n"
"---EXAMPLE---\n"
"项目运行情况投产后烷基化油收率保持在86%以上高于设计值81.28%"
"受原油加工负荷影响,阶段性加工负荷与可研预期存在偏差;"
"通过全厂优化运行后烷基化装置加工负荷保持在90%以上,"
"满足国VI汽油质量升级需要。",
),
(
"2",
"前期决策评价项目可研由具备甲级资质单位编制前评估提出57条意见并完成落实。"
"原料来源为MTBE装置醚后碳四来源与实际生产一致"
"产品全部调入汽油系统,由中石油西北销售公司统一销售,市场风险可控。"
"工艺路线采用中石油自主硫酸法烷基化技术并配套P&P湿法废酸再生技术。\n"
"---EXAMPLE---\n"
"前期工作结论:项目在可研、前评估、初设及决策程序方面总体规范,"
"对国VI质量升级任务支撑明确"
"后续应针对专利技术工程化经验不足、概算约束偏紧等问题,"
"在施工图阶段加强工程量校核与投资风险预控。",
),
(
"3",
"建设实施评价:项目采用“业主+监理+E+PC”管理模式。"
"单位工程质量合格率100%HSE总体受控。"
"但施工图设计进度与采购协同不足,废酸再生单元受工艺包与设备整改影响,"
"中交与投产明显滞后。\n"
"---EXAMPLE---\n"
"招采与设计变更情况共发生设计变更118份变更费用约5033万元"
"对进度与投资控制产生不利影响。"
"经验上应优先采用以设计为龙头的EPC协同模式"
"并提前锁定关键设备选材和高温腐蚀工况边界条件。",
),
(
"4",
"生产运行评价烷基化单元2018年11月一次投产成功"
"废酸再生单元经四次整改后于2020年11月投运。"
"运行期主要问题集中在原料杂质波动、局部腐蚀及部分指标偏离设计值,"
"通过优化烷烯比、补酸策略及上游分离精度逐步改进。\n"
"---EXAMPLE---\n"
"达标评价结论:烷基化单元在标定中出现辛烷值、硫含量、酸耗偏差,"
"废酸再生单元处理能力达标但能耗偏高。"
"总体上项目工艺可用、装置可稳态运行,"
"需持续优化操作和设备防腐管理以提升长周期绩效。",
),
(
"5",
"主要经济指标对比示例(宁夏石化项目):\n\n"
"表5-1 主要经济指标对比表\n"
"| 指标 | 单位 | 可研值 | 后评价值 | 差值 |\n"
"| --- | --- | --- | --- | --- |\n"
"| 报批总投资 | 万元 | 22812 | 32486 | +9674 |\n"
"| 年均税后利润 | 万元 | 20652 | 13283 | -7369 |\n"
"| 税后内部收益率 | % | 85.12 | 35.46 | -49.66 |\n"
"| 静态投资回收期 | 年 | 2.29 | 4.38 | +2.09 |\n"
"\n"
"示例结论:项目收益率虽较可研明显下降,但仍高于基准收益率,"
"基本实现效益目标;投资控制偏弱是主要短板。\n"
"---EXAMPLE---\n"
"5.2.2 投资水平分析正文参考勿输出为表5-2表5-2 仅用于 5.2.1 投资变动情况表):\n\n"
"同类烷基化装置单位工程费对标(撰写段落时参考,非模版表号):\n"
"| 项目 | 规模(万吨/年) | 工程费(万元) | 单位造价(元/吨) |\n"
"| --- | --- | --- | --- |\n"
"| 宁夏石化 | 16 | 15159 | 947 |\n"
"| 乌石化 | 20 | 21286 | 1064 |\n"
"| 锦州石化 | 25 | 23401 | 936 |\n"
"| 兰州石化 | 20 | 14377 | 719 |",
),
(
"6",
"影响评价示例:项目落实环评及安全“三同时”,废气、废水、噪声监测达标,"
"投产以来未发生重大安全事故,环境与安全风险总体可控。\n"
"---EXAMPLE---\n"
"持续性评价示例(宁夏石化项目):\n\n"
"表6-1 装置技术经济指标对比表\n"
"| 项目名称 | 技术来源 | 规模(万吨/年) | 物耗(Wt)% | 能耗(kgEo/t) | 产品质量 | 产品收率Wt% | 排名 |\n"
"| --- | --- | --- | --- | --- | --- | --- | --- |\n"
"| 宁夏石化烷基化 | 自主硫酸法烷基化 | 16 | 待补充 | 待补充 | 国VI调和组分 | 86以上 | 待补充 |\n"
"| 同类装置A | … | … | … | … | … | … | … |",
),
(
"7",
"综合结论示例宁夏石化16万吨/年烷基化项目完成了国VI汽油质量升级目标"
"生产运行总体平稳,效益指标虽低于可研预测但仍高于基准收益率。"
"项目综合评分8.62分,评级为“良”。\n"
"---EXAMPLE---\n"
"建议示例:"
"1持续优化自主技术工程化能力重点治理腐蚀与高温环节选材问题"
"2加强设计-采购-施工一体化协同,减少大额变更;"
"3围绕原料品质与上游协同运行进一步提升装置长期经济性。",
),
]

View File

@ -1,23 +1,8 @@
# Web 框架 fastapi>=0.115.6
fastapi uvicorn[standard]>=0.34.0
uvicorn[standard] python-multipart>=0.0.20
pydantic pydantic>=2.11
pydantic-settings pydantic-settings>=2.7.1
SQLAlchemy>=2.0.36
# 数据库MySQL PyMySQL>=1.1.1
sqlalchemy requests>=2.32.3
pymysql
cryptography
# HTTPLLM / Embedding 调用)
requests
# 附图提取(解析项目 .docx 内嵌图片)
python-docx
# 向量检索Milvus + Embeddings + BM25
langchain-core
langchain-text-splitters
langchain-openai
langchain-milvus
pymilvus

346
routers/template.py Normal file
View File

@ -0,0 +1,346 @@
"""
routers/template.py
报告模板管理上传文档 远程解析为 Markdown 抽取目录并为每个目录生成声明
创建模板目录 + 声明 按章节拆分正文并入库远程 MySQL
"""
from __future__ import annotations
import asyncio
import logging
import os
import re
import tempfile
import uuid
from datetime import datetime
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
from sqlalchemy import or_
from sqlalchemy.orm import Session
from database import get_db
from database.models import (
ReportSectionReference,
ReportTemplate,
ReportTemplateSection,
)
from schemas.template import (
SectionReferenceItem,
TemplateItem,
TemplateSectionItem,
UploadTemplateResult,
)
from services.template_prompt_mapper import resolve_uploaded_template_prompts
from services.desensitize_service import count_masked_numbers, desensitize_content
from services.file_parse_client import parse_file_to_markdown
from services.section_extractor import (
clamp_text_bytes,
extract_sections_from_text,
normalize_section_key,
parse_section_order,
split_markdown_into_sections,
)
from config import settings
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/templates", tags=["报告模板管理"])
ALLOWED_SUFFIXES = {".doc", ".docx", ".pdf", ".txt", ".md", ".html", ".htm", ".rtf"}
def _clamp_title(value: str | None) -> str:
return str(value or "").strip()[:255]
def _build_description(source_file: str | None) -> str:
sf = str(source_file or "").strip()
base = "通过文件上传导入"
return f"{base}\n来源文件:{sf}" if sf else base
def _find_duplicates(db: Session, filename: str) -> dict:
"""按来源文件名查找已导入的模板与章节范文,用于重复检查。"""
template_ids = [
t.id
for t in db.query(ReportTemplate.id)
.filter(ReportTemplate.description.like(f"%来源文件:{filename}%"))
.all()
]
ref_count = (
db.query(ReportSectionReference)
.filter(ReportSectionReference.source_file == filename)
.count()
)
return {"template_ids": template_ids, "ref_count": ref_count}
def _delete_by_source(db: Session, filename: str, template_ids: list[str]) -> None:
"""删除指定来源文件已导入的模板(含章节)与章节范文。"""
for tid in template_ids:
db.query(ReportTemplateSection).filter(
ReportTemplateSection.template_id == tid
).delete(synchronize_session=False)
db.query(ReportTemplate).filter(ReportTemplate.id == tid).delete(
synchronize_session=False
)
db.query(ReportSectionReference).filter(
ReportSectionReference.source_file == filename
).delete(synchronize_session=False)
db.commit()
def _extract_source_file(description: str | None) -> str | None:
"""从模板描述中解析来源文件名(上传时写入 '来源文件xxx')。"""
m = re.search(r"来源文件\s*[:]\s*(.+)$", str(description or ""))
if not m:
return None
return (m.group(1) or "").strip() or None
def _serialize_template(db: Session, template_id: str) -> TemplateItem:
t = db.query(ReportTemplate).filter(ReportTemplate.id == template_id).first()
if not t:
raise HTTPException(status_code=404, detail="模板不存在")
sections = (
db.query(ReportTemplateSection)
.filter(ReportTemplateSection.template_id == t.id)
.order_by(ReportTemplateSection.section_order.asc())
.all()
)
src = _extract_source_file(t.description)
return TemplateItem(
id=t.id,
name=t.name,
description=t.description,
sourceFile=src,
createdAt=t.created_at.strftime("%Y-%m-%d %H:%M:%S") if t.created_at else None,
updatedAt=t.updated_at.strftime("%Y-%m-%d %H:%M:%S") if t.updated_at else None,
isDefault=t.is_default,
isActive=t.is_active,
sections=[
TemplateSectionItem(
id=s.id,
sectionKey=s.section_key,
sectionTitle=s.section_title,
sectionPrompt=s.section_prompt,
sectionOutputContract=s.section_output_contract,
sectionOrder=s.section_order,
examples=s.examples,
)
for s in sections
],
)
@router.post("/upload", response_model=UploadTemplateResult, summary="上传文档并解析为模板(目录+声明)与章节内容")
async def upload_template_route(
file: UploadFile = File(...),
force: bool = Query(False, description="为 true 时覆盖同名来源文件的已有模板与章节后重新导入"),
db: Session = Depends(get_db),
):
filename = (file.filename or "").strip()
suffix = os.path.splitext(filename)[1].lower()
if suffix not in ALLOWED_SUFFIXES:
raise HTTPException(
status_code=400,
detail=f"不支持的文件格式({suffix or '未知'});支持:{', '.join(sorted(ALLOWED_SUFFIXES))}",
)
# 0) 重复检查同名来源文件已导入则拒绝force=true 则先删除旧数据再重导)
dup = _find_duplicates(db, filename)
if dup["template_ids"] or dup["ref_count"]:
if not force:
raise HTTPException(
status_code=409,
detail=(
f"文件「{filename}」已导入(模板 {len(dup['template_ids'])} 个、"
f"章节 {dup['ref_count']} 条),不可重复入库;"
f"如需覆盖请使用 force=true 重新上传。"
),
)
logger.info(
"重复导入覆盖 | file=%s | 删除旧模板=%s | 旧章节=%s",
filename, len(dup["template_ids"]), dup["ref_count"],
)
_delete_by_source(db, filename, dup["template_ids"])
# 1) 落盘临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
tmp_path = tmp.name
tmp.write(await file.read())
try:
# 2) 远程解析为 Markdown
try:
markdown = parse_file_to_markdown(tmp_path)
except Exception as e: # noqa: BLE001
raise HTTPException(status_code=400, detail=f"文档解析失败:{e}")
if not markdown or not markdown.strip():
raise HTTPException(status_code=400, detail="解析结果为空")
# 3) 抽取目录
sections = extract_sections_from_text(markdown)
if not sections:
raise HTTPException(status_code=400, detail="未从文档中识别到章节标题(目录)")
# 4) 按标题拆分正文,对每段正文脱敏(去精确数字等);跳过无正文的章节
raw_sections = split_markdown_into_sections(markdown)
max_bytes = int(getattr(settings, "SECTION_CONTENT_MAX_BYTES", 60000) or 60000)
ref_sections: list[dict] = []
masked_total = 0
skipped_empty = 0
truncated = 0
for s in raw_sections:
filtered = desensitize_content(s["content"])
if not filtered.strip():
skipped_empty += 1
continue
masked_total += max(count_masked_numbers(s["content"], filtered), 0)
clamped = clamp_text_bytes(filtered, max_bytes)
if clamped is not filtered and len(clamped) != len(filtered):
truncated += 1
ref_sections.append({**s, "content": clamped})
logger.info(
"解析结果 | md_len=%s | toc=%s | 入库章节=%s | 跳过空正文=%s | 截断超长=%s | 脱敏数字串=%s",
len(markdown), len(sections), len(ref_sections), skipped_empty, truncated, masked_total,
)
# 5) 复刻 eval_report将上传目录匹配默认模板得到每章节 提示词/输出合同/示例
# 放到工作线程执行:内部含并行 LLM 调用,避免阻塞事件循环(上传期间仍可并发处理其它请求)
resolved = await asyncio.to_thread(resolve_uploaded_template_prompts, sections)
logger.info(
"提示词匹配完成 | 章节=%s | 命中提示词=%s",
len(resolved), sum(1 for r in resolved if (r.get("sectionPrompt") or "").strip()),
)
now = datetime.now()
# 6) 创建模板(目录 + 提示词/输出合同/示例)
template = ReportTemplate(
id=uuid.uuid4().hex,
name=os.path.splitext(filename)[0] or "上传模板",
description=_build_description(filename),
is_default=False,
is_active=True,
created_at=now,
updated_at=now,
)
db.add(template)
db.flush()
for i, sec in enumerate(sections):
r = resolved[i] if i < len(resolved) else {}
db.add(
ReportTemplateSection(
id=uuid.uuid4().hex,
template_id=template.id,
section_key=normalize_section_key(sec["sectionKey"], sec["sectionTitle"]),
section_title=_clamp_title(sec["sectionTitle"]),
section_prompt=(r.get("sectionPrompt") or None),
section_output_contract=(r.get("sectionOutputContract") or None),
section_order=i,
examples="",
created_at=now,
updated_at=now,
)
)
# 7) 章节内容入库report_section_references 格式)
saved_refs: list[SectionReferenceItem] = []
for sec in ref_sections:
ref = ReportSectionReference(
id=uuid.uuid4().hex,
template_id=template.id,
source_file=filename,
section_key=sec["section_key"],
section_title=_clamp_title(sec["section_title"]),
section_order=parse_section_order(sec["section_key"]),
content=sec["content"],
created_at=now,
updated_at=now,
)
db.add(ref)
saved_refs.append(
SectionReferenceItem(
id=ref.id,
templateId=ref.template_id,
sourceFile=ref.source_file,
sectionKey=ref.section_key,
sectionTitle=ref.section_title,
sectionOrder=ref.section_order,
contentLength=len(ref.content or ""),
content=ref.content or "",
)
)
db.commit()
logger.info(
"模板上传完成 | file=%s | toc=%s | refs=%s",
filename, len(sections), len(saved_refs),
)
return UploadTemplateResult(
template=_serialize_template(db, template.id),
sourceFile=filename,
markdownLength=len(markdown),
totalSections=len(sections),
totalReferences=len(saved_refs),
references=saved_refs,
parseWarnings=[],
)
except HTTPException:
db.rollback()
raise
except Exception as e: # noqa: BLE001
db.rollback()
raise HTTPException(status_code=400, detail=f"模板创建失败:{e}")
finally:
try:
os.remove(tmp_path)
except OSError:
pass
@router.get("", response_model=list[TemplateItem], summary="获取模板列表")
def list_templates_route(db: Session = Depends(get_db)):
rows = (
db.query(ReportTemplate)
.order_by(ReportTemplate.created_at.desc())
.all()
)
return [_serialize_template(db, r.id) for r in rows]
@router.get("/{template_id}", response_model=TemplateItem, summary="获取模板详情")
def get_template_route(template_id: str, db: Session = Depends(get_db)):
return _serialize_template(db, template_id)
@router.delete("/{template_id}", status_code=204, summary="删除模板(含其来源文件的章节范文)")
def delete_template_route(template_id: str, db: Session = Depends(get_db)):
t = db.query(ReportTemplate).filter(ReportTemplate.id == template_id).first()
if not t:
raise HTTPException(status_code=404, detail="模板不存在")
source_file = _extract_source_file(t.description)
db.query(ReportTemplateSection).filter(
ReportTemplateSection.template_id == t.id
).delete(synchronize_session=False)
db.delete(t)
# 同步删除该模板在 report_section_references 中的章节内容:
# 优先按 template_id 精确删除;同时按 source_file 兜底清理历史数据template_id 为空的旧记录)
conditions = [ReportSectionReference.template_id == t.id]
if source_file:
conditions.append(ReportSectionReference.source_file == source_file)
ref_deleted = (
db.query(ReportSectionReference)
.filter(or_(*conditions))
.delete(synchronize_session=False)
)
db.commit()
logger.info(
"模板删除完成 | template_id=%s | source_file=%s | 删除章节范文=%s",
template_id, source_file, ref_deleted,
)

51
schemas/template.py Normal file
View File

@ -0,0 +1,51 @@
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel
class TemplateSectionItem(BaseModel):
id: str
sectionKey: str
sectionTitle: str
# 复刻 eval_report章节提示词 / 输出合同 / 示例
sectionPrompt: Optional[str] = None
sectionOutputContract: Optional[str] = None
sectionOrder: int = 0
examples: Optional[str] = None
class TemplateItem(BaseModel):
id: str
name: str
description: Optional[str] = None
sourceFile: Optional[str] = None
createdAt: Optional[str] = None
updatedAt: Optional[str] = None
isDefault: bool = False
isActive: bool = True
sections: List[TemplateSectionItem] = []
class SectionReferenceItem(BaseModel):
id: str
templateId: Optional[str] = None
sourceFile: str
sectionKey: str
sectionTitle: str
sectionOrder: int = 0
contentLength: int = 0
content: str = ""
class UploadTemplateResult(BaseModel):
"""上传解析结果:模板(目录 + 声明)+ 入库的章节内容。"""
template: TemplateItem
sourceFile: str
markdownLength: int
totalSections: int
totalReferences: int
references: List[SectionReferenceItem] = []
parseWarnings: List[str] = []

View File

@ -0,0 +1,126 @@
"""
services/declaration_service.py
为每个目录章节生成一个"声明"
声明一段说明该章节应写什么结构与约束的撰写指引存入
report_template_sections.section_prompt
优先用 LLM结合章节标题 + 该章节正文生成未配置或失败时
回退到确定性模板保证流程稳定可用
"""
from __future__ import annotations
import logging
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from config import settings
from services.llm_client import chat_completions_json, llm_configured
logger = logging.getLogger(__name__)
_SYSTEM_PROMPT = (
"你是报告模板专家。任务:阅读给定章节的范文(参考正文),总结这一章应该怎么写,"
"作为后续报告撰写该章节的写作指引。需提炼:①内容要点(写哪些事项);"
"②组织结构(应有的小节/条目顺序);③数据与口径要求(需引用的对比/指标/表格等);"
"④写作约束(先事实后评价、缺失写「待补充」、不得编造)。"
"严格要求:不要输出任何思考过程或解释;只输出 JSON 对象 {\"guide\": \"...\"}"
"guide 为 300 字以内的写作指引纯文本(不含 markdown 标题);"
"范文缺失或过短时,按章节标题给出通用写作指引。"
)
def _strip_number_prefix(title: str) -> str:
t = str(title or "").strip()
t = re.sub(r"^(?:\d+(?:\.\d+)*|[一二三四五六七八九十]+[、.])\s*", "", t).strip()
return t
def _fallback_declaration(section_title: str) -> str:
label = _strip_number_prefix(section_title) or "本章节"
return (
f"本章节为「{label}」。撰写时应紧扣标题主题,先陈述事实与数据,再给出分析与评价;"
f"结构需与标题保持一致,条理清晰、用语规范;"
f"所有结论须有依据,缺失信息写「待补充」,禁止编造。"
)
def _build_user_prompt(section_title: str, content: str) -> str:
body = (content or "").strip()
if len(body) > 2500:
body = body[:2500]
body_block = f"\n\n该章节范文(参考正文,节选):\n```\n{body}\n```" if body else ""
return (
f"章节标题:{section_title}{body_block}\n\n"
f"请根据上述范文,总结该章节应该怎么写,并只返回 JSON{{\"guide\": \"300字以内的写作指引\"}}"
)
def generate_declaration(section_title: str, content: str = "") -> str:
"""根据范文为单个章节生成"怎么写"的写作指引JSON 取 guide自动剔除思考过程"""
use_llm = bool(getattr(settings, "DECLARATION_USE_LLM", True)) and llm_configured()
if not use_llm:
return _fallback_declaration(section_title)
try:
data = chat_completions_json(
system_prompt=_SYSTEM_PROMPT,
user_prompt=_build_user_prompt(section_title, content),
temperature=0.2,
max_tokens=2048,
)
guide = str((data or {}).get("guide") or "").strip()
if guide:
return guide
except Exception as e: # noqa: BLE001 - 兜底,保证主流程不被 LLM 影响
logger.warning("生成章节声明失败,使用兜底模板 | title=%s | err=%s", section_title, e)
return _fallback_declaration(section_title)
def _content_for_section(s: dict, content_by_key: dict[str, str]) -> str:
"""目录键可能是 canonical 形式,优先用标题中的编号前缀去匹配正文。"""
title = str(s.get("sectionTitle") or "")
m = re.match(r"^(\d+(?:\.\d+)*)", title.strip())
num = m.group(1) if m else ""
return content_by_key.get(num, "") or content_by_key.get(str(s.get("sectionKey") or ""), "")
def generate_declarations(sections: list[dict], content_by_key: dict[str, str] | None = None) -> list[str]:
"""
为目录中每个章节并发生成"怎么写"的写作指引基于范文
sections: [{sectionKey, sectionTitle}, ...]
content_by_key: 章节编号/ -> 范文正文用于为指引提供上下文可选
每章一次 LLM 调用多线程并发以打满 GPULLM 为网络 I/O线程下真正并行
"""
content_by_key = content_by_key or {}
tasks = [(str(s.get("sectionTitle") or ""), _content_for_section(s, content_by_key)) for s in sections]
if not tasks:
return []
use_llm = bool(getattr(settings, "DECLARATION_USE_LLM", True)) and llm_configured()
if not use_llm:
return [_fallback_declaration(title) for title, _ in tasks]
max_workers = max(int(getattr(settings, "TEMPLATE_UPLOAD_LLM_MAX_WORKERS", 8) or 8), 1)
results: list[str] = [""] * len(tasks)
if len(tasks) == 1:
results[0] = generate_declaration(*tasks[0])
return results
workers = min(max_workers, len(tasks))
with ThreadPoolExecutor(max_workers=workers) as executor:
future_to_idx = {
executor.submit(generate_declaration, title, content): i
for i, (title, content) in enumerate(tasks)
}
for fut in as_completed(future_to_idx):
idx = future_to_idx[fut]
try:
results[idx] = fut.result()
except Exception as e: # noqa: BLE001
logger.warning("章节声明并发生成失败,使用兜底 | idx=%s | err=%s", idx, e)
results[idx] = _fallback_declaration(tasks[idx][0])
logger.info("章节声明生成 | 章节=%s | 线程=%s", len(tasks), workers)
return results

View File

@ -0,0 +1,80 @@
"""
services/desensitize_service.py
章节内容脱敏把范文正文中的"精确数据"过滤掉得到可复用的模板化内容
规则默认
- 阿拉伯数字串含小数千分位全角数字 占位符默认 "X"
"总投资10.5亿元" "总投资X亿元""2020年3月" "X年X月""85.3%" "X%"
- 标题行 # 开头)整行保留,不动章节编号/标题。
- 行首的列表/枚举序号 "1" "1." "2"保留仅脱敏正文中的数字
- 单位与符号万元/亿元/%// 保留仅去掉其中的精确数值
可通过 config 调整占位符是否脱敏表格数字是否启用
中文数字一二三通常用于序数/层级默认保留
"""
from __future__ import annotations
import logging
import re
from config import settings
logger = logging.getLogger(__name__)
# 阿拉伯数字(含全角)串,允许小数点/千分位分隔
_NUMBER_RE = re.compile(r"[0-9-]+(?:[.,][0-9-]+)*")
# 行首枚举序号1 / 1. / 2 / 2、 等(这些是结构标记,保留)
_LEADING_ENUM_RE = re.compile(r"^(\s*(?:[(]\s*[0-9-]+\s*[)]|[0-9-]+\s*[).、.]))")
_HEADING_RE = re.compile(r"^\s*#{1,6}\s")
_TABLE_ROW_RE = re.compile(r"^\s*\|.*\|\s*$")
_TABLE_SEP_RE = re.compile(r"^\s*\|?[\s:\-|]+\|?\s*$")
def _mask_numbers(segment: str, placeholder: str) -> str:
return _NUMBER_RE.sub(placeholder, segment)
def _desensitize_line(line: str, placeholder: str, mask_table_numbers: bool) -> str:
# 标题行整行保留(不动章节编号/标题)
if _HEADING_RE.match(line):
return line
# 表格行
if _TABLE_ROW_RE.match(line):
if _TABLE_SEP_RE.match(line): # 分隔行 |---|---|
return line
if not mask_table_numbers:
return line
return _mask_numbers(line, placeholder)
# 普通正文:保留行首枚举序号,仅脱敏其余部分
m = _LEADING_ENUM_RE.match(line)
if m:
prefix = m.group(1)
rest = line[len(prefix):]
return prefix + _mask_numbers(rest, placeholder)
return _mask_numbers(line, placeholder)
def desensitize_content(text: str) -> str:
"""对单个章节正文脱敏。未启用时原样返回。"""
if not text:
return text
if not bool(getattr(settings, "DESENSITIZE_ENABLED", True)):
return text
placeholder = str(getattr(settings, "DESENSITIZE_PLACEHOLDER", "X") or "X")
mask_table = bool(getattr(settings, "DESENSITIZE_MASK_TABLE_NUMBERS", True))
lines = text.splitlines()
out = [_desensitize_line(ln, placeholder, mask_table) for ln in lines]
return "\n".join(out)
def count_masked_numbers(original: str, filtered: str) -> int:
"""粗略统计脱敏掉的数字串数量(用于日志)。"""
return len(_NUMBER_RE.findall(original or "")) - len(_NUMBER_RE.findall(filtered or ""))

View File

@ -0,0 +1,194 @@
"""
services/file_parse_client.py
调用远程解析服务默认 http://192.168.4.194:8000/convert
上传文件multipart文件字段默认 "file"并附带 engine=auto 表单字段 返回 Markdown
响应解析JSON 中按 results / md_content / mdcontent / markdown / content
逐层提取若响应非 JSON 则整体作为 Markdown 返回
"""
from __future__ import annotations
import json
import logging
import mimetypes
import time
import uuid
from pathlib import Path
from typing import Any
from urllib import error as urlerror
from urllib import request as urlrequest
from config import settings
logger = logging.getLogger(__name__)
MD_CONTENT_KEYS = ("md_content", "mdcontent", "markdown", "content")
class FileParseApiError(RuntimeError):
def __init__(self, message: str, *, status_code: int | None = None, api_url: str = "") -> None:
super().__init__(message)
self.status_code = status_code
self.api_url = api_url
def _build_multipart_body(
file_path: Path,
field_name: str,
extra_fields: dict[str, str] | None = None,
) -> tuple[bytes, str]:
boundary = uuid.uuid4().hex
mime_type = mimetypes.guess_type(file_path.name)[0] or "application/octet-stream"
file_bytes = file_path.read_bytes()
parts: list[bytes] = []
# 普通表单字段(如 engine=auto
for key, value in (extra_fields or {}).items():
if value is None or str(value).strip() == "":
continue
parts.append(f"--{boundary}\r\n".encode("utf-8"))
parts.append(
f'Content-Disposition: form-data; name="{key}"\r\n\r\n'.encode("utf-8")
)
parts.append(f"{value}\r\n".encode("utf-8"))
# 文件字段
parts.append(f"--{boundary}\r\n".encode("utf-8"))
parts.append(
(
f'Content-Disposition: form-data; name="{field_name}"; filename="{file_path.name}"\r\n'
f"Content-Type: {mime_type}\r\n\r\n"
).encode("utf-8")
)
parts.append(file_bytes)
parts.append(f"\r\n--{boundary}--\r\n".encode("utf-8"))
return b"".join(parts), boundary
def _extract_md_contents(payload: Any) -> list[str]:
if isinstance(payload, str):
return [payload]
if isinstance(payload, list):
out: list[str] = []
for item in payload:
out.extend(_extract_md_contents(item))
return out
if not isinstance(payload, dict):
return []
for key in MD_CONTENT_KEYS:
value = payload.get(key)
if isinstance(value, str):
return [value]
results = payload.get("results")
if results is not None:
return _extract_md_contents(results)
out = []
for value in payload.values():
out.extend(_extract_md_contents(value))
return out
def _response_to_markdown(text: str) -> str:
try:
payload = json.loads(text)
except json.JSONDecodeError:
# 非 JSON 直接当作 Markdown 返回
return text
contents = _extract_md_contents(payload)
if not contents:
raise ValueError("解析服务响应中未找到 md_content/markdown/content 字段")
return "\n\n".join(c.strip() for c in contents if c and c.strip())
def _request_once(
api_url: str,
file_path: Path,
field_name: str,
*,
timeout_sec: int,
extra_fields: dict[str, str] | None = None,
) -> str:
body, boundary = _build_multipart_body(file_path, field_name, extra_fields)
req = urlrequest.Request(
api_url,
data=body,
method="POST",
headers={"content-type": f"multipart/form-data; boundary={boundary}"},
)
try:
with urlrequest.urlopen(req, timeout=timeout_sec) as resp:
raw = resp.read()
encoding = resp.headers.get_content_charset() or "utf-8"
return raw.decode(encoding, errors="replace")
except urlerror.HTTPError as exc:
body_text = ""
try:
body_text = (exc.read() or b"").decode("utf-8", errors="replace")[:1000]
except Exception:
pass
raise FileParseApiError(
f"解析服务 HTTP {exc.code}{api_url}{body_text or exc.reason}",
status_code=int(exc.code or 0),
api_url=api_url,
) from exc
except urlerror.URLError as exc:
raise FileParseApiError(
f"无法连接解析服务({api_url}{exc.reason}",
status_code=0,
api_url=api_url,
) from exc
def parse_file_to_markdown(file_path: str | Path) -> str:
"""
将上传文件通过远程 file_parse 服务转换为 Markdown
失败时对 5xx 做有限重试
"""
path = Path(file_path)
if not path.is_file():
raise FileNotFoundError(f"文件不存在: {path}")
api_url = str(settings.FILE_PARSE_API_URL or "").strip()
if not api_url:
raise ValueError("FILE_PARSE_API_URL 未配置")
field_name = str(settings.FILE_PARSE_FIELD_NAME or "file").strip() or "file"
timeout_sec = max(int(settings.FILE_PARSE_HTTP_TIMEOUT_SEC or 600), 30)
retry_count = max(int(settings.FILE_PARSE_RETRY_COUNT or 1), 1)
backoff_sec = max(float(settings.FILE_PARSE_RETRY_BACKOFF_SEC or 1.0), 1.0)
retryable_status = {500, 502, 503, 504}
extra_fields: dict[str, str] = {}
engine = str(getattr(settings, "FILE_PARSE_ENGINE", "") or "").strip()
if engine:
extra_fields["engine"] = engine
last_error: Exception | None = None
for attempt in range(1, retry_count + 1):
try:
raw = _request_once(
api_url, path, field_name, timeout_sec=timeout_sec, extra_fields=extra_fields
)
markdown = _response_to_markdown(raw)
if not markdown.strip():
raise ValueError("解析服务返回的 Markdown 为空")
return markdown
except FileParseApiError as exc:
last_error = exc
status = int(exc.status_code or 0)
if attempt >= retry_count or status not in retryable_status:
raise
wait = backoff_sec * attempt
logger.warning(
"file_parse 重试 %s/%s status=%s wait=%ss file=%s",
attempt, retry_count, status, wait, path.name,
)
time.sleep(wait)
if last_error:
raise last_error
raise RuntimeError("file_parse 请求失败")

View File

@ -1,724 +1,118 @@
"""
services/llm_client.py
极简 OpenAI 兼容 Chat Completions 客户端仅用于生成章节声明可选
"""
from __future__ import annotations from __future__ import annotations
import json import json
import logging import logging
import random
import re import re
import time
import threading
from typing import Any, Optional
import requests import requests
from requests import RequestException
from requests.exceptions import ChunkedEncodingError
from config import settings from config import settings
_logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 生成全过程追踪:完整记录输入 prompt / 调用模型 / 模型输出,写入 logs/generation_trace.log
_trace_logger = logging.getLogger("generation.trace")
_LLM_MAX_CONCURRENCY = 5
_llm_slots = threading.BoundedSemaphore(_LLM_MAX_CONCURRENCY)
class _RetryableLLMError(RuntimeError): def llm_configured() -> bool:
"""用于标记可安全重试的 LLM 调用异常。""" return bool(
str(settings.LLM_API_BASE or "").strip()
and str(settings.LLM_API_KEY or "").strip()
and str(settings.LLM_MODEL_NAME or "").strip()
)
class _ContentFieldStreamExtractor: _THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
"""从流式 JSON 文本中增量提取 content 字段的已解码正文。"""
def __init__(self) -> None:
self._raw = ""
self._content_started = False
self._content_done = False
self._value_start = -1
self._consumed_pos = 0
def feed(self, chunk: str) -> tuple[str, bool]:
if not chunk:
return "", False
self._raw += chunk
emitted = ""
done_now = False
if not self._content_started:
marker = '"content"'
idx = self._raw.find(marker)
if idx == -1:
return "", False
colon = self._raw.find(":", idx + len(marker))
if colon == -1:
return "", False
quote = self._raw.find('"', colon + 1)
if quote == -1:
return "", False
self._content_started = True
self._value_start = quote + 1
self._consumed_pos = self._value_start
if self._content_started and not self._content_done:
emitted, consumed_pos, done_now = self._decode_partial_json_string(
self._raw,
self._consumed_pos,
)
self._consumed_pos = consumed_pos
if done_now:
self._content_done = True
return emitted, done_now
@staticmethod
def _decode_partial_json_string(src: str, start: int) -> tuple[str, int, bool]:
out: list[str] = []
i = start
n = len(src)
while i < n:
ch = src[i]
if ch == '"':
if i == 0 or src[i - 1] != "\\" or _ContentFieldStreamExtractor._is_escaped(src, i):
return "".join(out), i, True
if ch != "\\":
out.append(ch)
i += 1
continue
if i + 1 >= n:
break
esc = src[i + 1]
mapping = {
'"': '"',
"\\": "\\",
"/": "/",
"b": "\b",
"f": "\f",
"n": "\n",
"r": "\r",
"t": "\t",
}
if esc == "u":
if i + 6 > n:
break
hex_part = src[i + 2 : i + 6]
try:
out.append(chr(int(hex_part, 16)))
except Exception:
pass
i += 6
continue
if esc in mapping:
out.append(mapping[esc])
i += 2
continue
out.append(esc)
i += 2
return "".join(out), i, False
@staticmethod
def _is_escaped(src: str, quote_index: int) -> bool:
backslashes = 0
i = quote_index - 1
while i >= 0 and src[i] == "\\":
backslashes += 1
i -= 1
return backslashes % 2 == 0
def _format_exc_raw(e: Exception) -> str: def _strip_reasoning(text: str) -> str:
"""统一输出最直接的异常原文(类型 + repr""" """去掉思考模型的思维链:成对 <think>…</think>,以及截断/前导的残留标签。"""
return f"{type(e).__name__}: {e!r}" s = text or ""
s = _THINK_BLOCK_RE.sub("", s)
# 仅剩结束标签时,说明前面是未配对的思考段,取最后一个 </think> 之后的正文
if "</think>" in s:
s = s.rsplit("</think>", 1)[-1]
s = re.sub(r"</?think>", "", s, flags=re.IGNORECASE)
return s.strip()
def _chat_completions_stream_text( def chat_completion_text(
*, *,
api_base: str,
api_key: str,
model_name: str,
system_prompt: str, system_prompt: str,
user_prompt: str, user_prompt: str,
temperature: float, temperature: float = 0.2,
max_tokens: int, max_tokens: int = 512,
extra_payload: dict[str, Any], timeout_sec: int | None = None,
connect_timeout_sec: int,
read_timeout_sec: int = 300,
on_content_delta: Optional[callable] = None,
) -> str: ) -> str:
""" """调用 LLM 返回纯文本。失败抛出异常,由调用方决定是否兜底。"""
OpenAI-compat SSE 流式读取模型输出文本 base = str(settings.LLM_API_BASE or "").strip().rstrip("/")
- connect timeout 保留避免连接阶段长时间卡死 url = f"{base}/chat/completions"
- read timeout 防止流式读取无限挂起默认 300s headers = {
""" "Authorization": f"Bearer {settings.LLM_API_KEY}",
_logger.info( "Content-Type": "application/json",
"LLM 流式调用开始 | model=%s | temperature=%s | max_tokens=%s | timeout_connect=%s timeout_read=%s", }
model_name, temperature, max_tokens, connect_timeout_sec, read_timeout_sec, payload = {
) "model": settings.LLM_MODEL_NAME,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"temperature": temperature,
"max_tokens": max_tokens,
}
# 关闭思考模型的思维链vLLM/Qwen3 等支持该扩展字段;不支持的服务会忽略)
if bool(getattr(settings, "LLM_DISABLE_THINKING", False)):
payload["chat_template_kwargs"] = {"enable_thinking": False}
resp = requests.post( resp = requests.post(
f"{api_base}/chat/completions", url,
headers={ headers=headers,
"Authorization": f"Bearer {api_key}", data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
"Content-Type": "application/json", timeout=timeout_sec or int(settings.LLM_HTTP_TIMEOUT_SEC or 120),
},
json={
"model": model_name,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"temperature": temperature,
"max_tokens": max_tokens,
"response_format": {"type": "json_object"},
"stream": True,
**extra_payload,
},
stream=True,
timeout=(connect_timeout_sec, max(60, read_timeout_sec)),
) )
resp.raise_for_status()
data = resp.json()
return _strip_reasoning((data["choices"][0]["message"]["content"] or "").strip())
if resp.status_code in (408, 429, 500, 502, 503, 504):
raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}")
if resp.status_code != 200:
raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")
resp.encoding = "utf-8" def _extract_json(text: str) -> dict:
chunks: list[str] = [] """从模型输出中解析 JSON object容忍 ```json``` 代码块包裹)。"""
extractor = _ContentFieldStreamExtractor() s = (text or "").strip()
if s.startswith("```"):
s = re.sub(r"^```[a-zA-Z]*\s*", "", s)
s = re.sub(r"\s*```$", "", s).strip()
try: try:
for line in resp.iter_lines(decode_unicode=True): obj = json.loads(s)
if not line: except json.JSONDecodeError:
continue m = re.search(r"\{.*\}", s, flags=re.DOTALL)
s = line.strip() if not m:
if not s.startswith("data:"): return {}
continue try:
payload = s[5:].strip() obj = json.loads(m.group(0))
if not payload or payload == "[DONE]": except json.JSONDecodeError:
break return {}
try: return obj if isinstance(obj, dict) else {}
obj = json.loads(payload)
except Exception:
continue
choices = obj.get("choices")
if not isinstance(choices, list) or not choices:
continue
first = choices[0] if isinstance(choices[0], dict) else {}
delta = first.get("delta") if isinstance(first.get("delta"), dict) else {}
content = delta.get("content")
if isinstance(content, str) and content:
chunks.append(content)
# print(content, end="", flush=True)
if on_content_delta:
delta_text, done_now = extractor.feed(content)
if delta_text:
try:
on_content_delta("delta", delta_text)
except Exception:
pass
if done_now:
try:
on_content_delta("finalizing", "")
except Exception:
pass
# 兼容部分实现把最终结果放在 message.content
message = first.get("message") if isinstance(first.get("message"), dict) else {}
msg_content = message.get("content")
if isinstance(msg_content, str) and msg_content:
chunks.append(msg_content)
# print(msg_content, end="", flush=True)
if on_content_delta:
delta_text, done_now = extractor.feed(msg_content)
if delta_text:
try:
on_content_delta("delta", delta_text)
except Exception:
pass
if done_now:
try:
on_content_delta("finalizing", "")
except Exception:
pass
except ChunkedEncodingError as e:
partial_text = "".join(chunks).strip()
# 若流提前结束但已收到完整 JSON直接使用避免无谓重试失败。
if partial_text:
try:
parse_json_object_from_text(partial_text)
return partial_text
except Exception:
pass
raise _RetryableLLMError(f"LLM 流中断: {_format_exc_raw(e)}") from e
text = "".join(chunks).strip()
if not text:
raise _RetryableLLMError("LLM 返回空内容")
print()
return text
def chat_completions_json( def chat_completions_json(
*, *,
system_prompt: str, system_prompt: str,
user_prompt: str, user_prompt: str,
temperature: float = 0.2, temperature: float = 0.1,
max_tokens: int = 4096, max_tokens: int = 4096,
timeout_sec: int = 180, timeout_sec: int | None = None,
on_content_delta: Optional[callable] = None, ) -> dict:
log_context: str = "", """调用 LLM 并将返回解析为 JSON objectdict。失败返回 {}"""
) -> dict[str, Any]: try:
""" text = chat_completion_text(
统一的 OpenAI-compat chat/completions 调用强制返回 JSON object system_prompt=system_prompt,
复用项目现有配置LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME user_prompt=user_prompt,
log_context: 调用来源标签如章节编号用于在 generation_trace.log 中区分各次生成调用
"""
api_base = (settings.LLM_API_BASE or "").rstrip("/")
api_key = settings.LLM_API_KEY or ""
model_name = settings.LLM_MODEL_NAME or ""
if not api_base or not api_key or not model_name:
raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME")
ctx = log_context or "-"
_trace_logger.info(
"[输入] context=%s | model=%s | temperature=%s | max_tokens=%s\n"
"----- SYSTEM PROMPT -----\n%s\n"
"----- USER PROMPT -----\n%s\n"
"----- END INPUT -----",
ctx, model_name, temperature, max_tokens, system_prompt, user_prompt,
)
extra_payload: dict[str, Any] = {}
# SiliconFlow 的部分 Qwen 模型默认把输出写到 reasoning_content导致 content 为空;
# 显式关闭 thinking确保最终输出进入 content避免下游解析失败。
if "siliconflow" in api_base.lower() and "qwen" in model_name.lower():
extra_payload["enable_thinking"] = False
final_timeout_sec = int(timeout_sec or 0)
if final_timeout_sec <= 0:
final_timeout_sec = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90)
retry_count = int(getattr(settings, "LLM_RETRY_COUNT", 2) or 2)
if retry_count < 1:
retry_count = 1
retry_backoff = float(getattr(settings, "LLM_RETRY_BACKOFF_SEC", 1.0) or 1.0)
retry_backoff_max = float(getattr(settings, "LLM_RETRY_BACKOFF_MAX_SEC", 12.0) or 12.0)
connect_timeout_sec = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20)
if connect_timeout_sec <= 0:
connect_timeout_sec = 20
use_stream = True
_logger.info(
"chat_completions_json 调用 | model=%s | temperature=%s | max_tokens=%s | timeout=%s | retry=%s",
model_name, temperature, max_tokens, final_timeout_sec, retry_count,
)
with _llm_slots:
last_err: Optional[Exception] = None
for attempt in range(retry_count):
try:
if use_stream:
content = _chat_completions_stream_text(
api_base=api_base,
api_key=api_key,
model_name=model_name,
system_prompt=system_prompt,
user_prompt=user_prompt,
temperature=temperature,
max_tokens=max_tokens,
extra_payload=extra_payload,
connect_timeout_sec=connect_timeout_sec,
read_timeout_sec=final_timeout_sec,
on_content_delta=on_content_delta,
)
else:
# 分离连接超时与读超时:长生成阶段只应占用「读」时间,避免与连接握手混在一个上限里过早超时
read_timeout = max(int(connect_timeout_sec) + 5, int(final_timeout_sec))
resp = requests.post(
f"{api_base}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
json={
"model": model_name,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"temperature": temperature,
"max_tokens": max_tokens,
"response_format": {"type": "json_object"},
**extra_payload,
},
timeout=(connect_timeout_sec, read_timeout),
)
if resp.status_code in (408, 429, 500, 502, 503, 504):
raise _RetryableLLMError(
f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}"
)
if resp.status_code != 200:
raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")
data = resp.json()
content = (
(data.get("choices") or [{}])[0]
.get("message", {})
.get("content", "")
)
if not isinstance(content, str) or not content.strip():
raise _RetryableLLMError("LLM 返回空内容")
try:
obj = parse_json_object_from_text(content)
_logger.info(
"chat_completions_json 成功 | model=%s | attempt=%d/%d | content_len=%d | keys=%s",
model_name, attempt + 1, retry_count, len(content), list(obj.keys())[:8],
)
_trace_logger.info(
"[输出] context=%s | model=%s | attempt=%d/%d | output_len=%d\n"
"----- MODEL OUTPUT -----\n%s\n"
"----- END OUTPUT -----",
ctx, model_name, attempt + 1, retry_count, len(content), content,
)
return obj
except ValueError as e:
raise _RetryableLLMError(f"LLM JSON 解析失败: {e}") from e
except (
requests.ReadTimeout,
requests.ConnectTimeout,
requests.ConnectionError,
ChunkedEncodingError,
) as e:
last_err = e
if attempt >= retry_count - 1:
raise RuntimeError(
"LLM 请求超时/连接失败"
f"(已重试{retry_count}timeout={final_timeout_sec}s"
f"endpoint={api_base}/chat/completions"
f"model={model_name}"
f"raw={_format_exc_raw(e)}"
) from e
sleep_sec = min(retry_backoff * (2 ** attempt), retry_backoff_max)
sleep_sec += random.uniform(0, min(0.5, sleep_sec * 0.2))
time.sleep(sleep_sec)
except _RetryableLLMError as e:
last_err = e
if attempt >= retry_count - 1:
raise RuntimeError(
f"{e}(已重试{retry_count}timeout={final_timeout_sec}s"
f"endpoint={api_base}/chat/completions"
f"model={model_name}"
f"raw={_format_exc_raw(e)}"
) from e
sleep_sec = min(retry_backoff * (2 ** attempt), retry_backoff_max)
sleep_sec += random.uniform(0, min(0.5, sleep_sec * 0.2))
time.sleep(sleep_sec)
except RequestException as e:
resp = getattr(e, "response", None)
status = getattr(resp, "status_code", None)
body = ""
if resp is not None:
try:
body = (resp.text or "")[:800]
except Exception:
body = ""
raise RuntimeError(
"LLM 请求失败"
f"endpoint={api_base}/chat/completions"
f"model={model_name}"
f"status={status}"
+ (f"body={body}" if body else "")
+ f"raw={_format_exc_raw(e)}"
) from e
else:
raw = _format_exc_raw(last_err) if isinstance(last_err, Exception) else str(last_err)
raise RuntimeError(
"LLM 请求失败"
f"endpoint={api_base}/chat/completions"
f"model={model_name}"
f"raw={raw}"
)
def _repair_loose_json_object(s: str) -> str:
"""常见模型输出问题:尾随逗号(, 后紧跟 } 或 ])。"""
return re.sub(r",(\s*[}\]])", r"\1", s)
def _extract_balanced_json_prefix(s: str) -> str:
"""
提取以 `{` 开始的最长可能完整 JSON 对象前缀
会忽略字符串内的花括号避免误判
"""
start = s.find("{")
if start == -1:
return s
in_string = False
escaped = False
depth = 0
end_idx = -1
for i, ch in enumerate(s[start:], start=start):
if in_string:
if escaped:
escaped = False
elif ch == "\\":
escaped = True
elif ch == '"':
in_string = False
continue
if ch == '"':
in_string = True
elif ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
end_idx = i
break
if end_idx != -1:
return s[start : end_idx + 1]
return s[start:]
def _close_truncated_json_object(s: str) -> str:
"""
处理模型截断导致的 JSON 残缺
- 若字符串未闭合补一个 `"`
- 按栈补齐缺失的 `}` / `]`
"""
out: list[str] = []
stack: list[str] = []
in_string = False
escaped = False
for ch in s:
out.append(ch)
if in_string:
if escaped:
escaped = False
elif ch == "\\":
escaped = True
elif ch == '"':
in_string = False
continue
if ch == '"':
in_string = True
continue
if ch == "{":
stack.append("}")
elif ch == "[":
stack.append("]")
elif ch in ("}", "]"):
if stack and stack[-1] == ch:
stack.pop()
if in_string:
out.append('"')
while stack:
out.append(stack.pop())
return "".join(out)
def parse_json_object_from_text(text: str) -> dict[str, Any]:
"""从模型输出里提取并解析 { ... } JSON 对象。"""
s = (text or "").strip()
s = re.sub(r"```(?:json)?", "", s, flags=re.IGNORECASE).replace("```", "").strip()
start = s.find("{")
if start == -1:
raise ValueError("未找到 JSON 对象")
chunk = s[start:]
balanced_chunk = _extract_balanced_json_prefix(chunk)
decoder = json.JSONDecoder()
last_err: Optional[Exception] = None
for candidate in (
balanced_chunk,
_repair_loose_json_object(balanced_chunk),
_close_truncated_json_object(_repair_loose_json_object(balanced_chunk)),
_close_truncated_json_object(_repair_loose_json_object(chunk)),
):
try:
obj, _ = decoder.raw_decode(candidate)
if not isinstance(obj, dict):
raise ValueError("JSON 根节点不是对象(dict)")
return obj
except json.JSONDecodeError as e:
last_err = e
raise ValueError(f"JSON 解析失败:{last_err}") from last_err
def safe_get_str(v: Any) -> Optional[str]:
if v is None:
return None
s = str(v).strip()
return s if s else None
# ------------------------------------------------------------------
# Agent 多轮对话 + Tool Calling流式生成器
# ------------------------------------------------------------------
def _iter_chat_stream_events(
*,
api_base: str,
api_key: str,
model_name: str,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
temperature: float = 0.3,
max_tokens: int = 4096,
extra_payload: dict[str, Any] | None = None,
connect_timeout_sec: int = 20,
read_timeout_sec: int = 300,
):
"""
流式调用 OpenAI-compat /chat/completions逐步 yield 事件
("delta", str) 文本增量
("tool_calls", list) 完整 tool_calls 列表 [{id, function:{name, arguments}}]
("done", dict) 最终 usage 等元信息
"""
payload: dict[str, Any] = {
"model": model_name,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True,
**(extra_payload or {}),
}
if tools:
payload["tools"] = tools
payload["tool_choice"] = "auto"
resp = requests.post(
f"{api_base}/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
json=payload,
stream=True,
timeout=(connect_timeout_sec, max(60, read_timeout_sec)),
)
if resp.status_code in (408, 429, 500, 502, 503, 504):
raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}")
if resp.status_code != 200:
raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")
resp.encoding = "utf-8"
content_parts: list[str] = []
tool_calls_map: dict[int, dict] = {}
for line in resp.iter_lines(decode_unicode=True):
if not line:
continue
s = line.strip()
if not s.startswith("data:"):
continue
data_str = s[5:].strip()
if not data_str or data_str == "[DONE]":
break
try:
obj = json.loads(data_str)
except Exception:
continue
choices = obj.get("choices")
if not isinstance(choices, list) or not choices:
continue
first = choices[0] if isinstance(choices[0], dict) else {}
delta = first.get("delta") if isinstance(first.get("delta"), dict) else {}
# 仅输出正文 content忽略 reasoning_content避免思考过程展示给用户
_ = delta.get("reasoning_content")
content = delta.get("content")
if isinstance(content, str) and content:
content_parts.append(content)
yield ("delta", content)
# tool_calls delta (streamed incrementally)
tc_deltas = delta.get("tool_calls")
if isinstance(tc_deltas, list):
for tc in tc_deltas:
if not isinstance(tc, dict):
continue
idx = tc.get("index", 0)
if idx not in tool_calls_map:
tool_calls_map[idx] = {
"id": tc.get("id", ""),
"type": "function",
"function": {"name": "", "arguments": ""},
}
entry = tool_calls_map[idx]
if tc.get("id"):
entry["id"] = tc["id"]
fn = tc.get("function") if isinstance(tc.get("function"), dict) else {}
if fn.get("name"):
entry["function"]["name"] += fn["name"]
if fn.get("arguments"):
entry["function"]["arguments"] += fn["arguments"]
# finish_reason
finish = first.get("finish_reason")
if finish == "tool_calls" and tool_calls_map:
ordered = [tool_calls_map[k] for k in sorted(tool_calls_map.keys())]
yield ("tool_calls", ordered)
tool_calls_map = {}
if tool_calls_map:
ordered = [tool_calls_map[k] for k in sorted(tool_calls_map.keys())]
yield ("tool_calls", ordered)
yield ("done", {"content": "".join(content_parts)})
def _default_disable_thinking_payload(model_name: str) -> dict[str, Any]:
"""Qwen 等推理模型:关闭 thinking仅将最终答案写入 content。"""
if not model_name or "qwen" not in str(model_name).lower():
return {}
return {
"enable_thinking": False,
# vLLM / 部分 OpenAI 兼容网关使用 chat_template_kwargs
"chat_template_kwargs": {"enable_thinking": False},
}
def chat_completions_with_tools(
*,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
temperature: float = 0.3,
max_tokens: int = 4096,
timeout_sec: int = 180,
extra_payload: dict[str, Any] | None = None,
):
"""
Agent 用多轮对话 + tool calling返回生成器yield 事件元组
调用方负责工具循环编排
"""
api_base = (settings.LLM_API_BASE or "").rstrip("/")
api_key = settings.LLM_API_KEY or ""
model_name = settings.LLM_MODEL_NAME or ""
if not api_base or not api_key or not model_name:
raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME")
merged_extra: dict[str, Any] = dict(_default_disable_thinking_payload(model_name))
if extra_payload:
merged_extra.update(extra_payload)
connect_timeout_sec = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20)
if connect_timeout_sec <= 0:
connect_timeout_sec = 20
final_timeout_sec = int(timeout_sec or 0)
if final_timeout_sec <= 0:
final_timeout_sec = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90)
with _llm_slots:
yield from _iter_chat_stream_events(
api_base=api_base,
api_key=api_key,
model_name=model_name,
messages=messages,
tools=tools,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
extra_payload=merged_extra or None, timeout_sec=timeout_sec,
connect_timeout_sec=connect_timeout_sec,
read_timeout_sec=final_timeout_sec,
) )
except Exception as e: # noqa: BLE001
logger.warning("chat_completions_json 调用失败: %s", e)
return {}
return _extract_json(text)

View File

@ -0,0 +1,406 @@
"""
services/section_extractor.py
Markdown
1) 抽取目录章节标题层级-> 用于生成模板章节目录
2) 按标题拆分正文 -> 每个章节的内容用于入库 report_section_references
抽取/过滤逻辑参考 eval_report/routers/template.py routers/reference.py
"""
from __future__ import annotations
import hashlib
import re
_MAX_SECTION_TITLE_LEN = 200
# ────────────────────────────── 通用过滤/清洗 ──────────────────────────────
def _segment_looks_like_year(segment: str) -> bool:
if not segment.isdigit() or len(segment) != 4:
return False
year = int(segment)
return 1900 <= year <= 2099
def _is_valid_section_number(num: str) -> bool:
"""章节编号形如 1 / 1.1 / 2.3.4排除正文年份2017、2019 等)。"""
parts = [p for p in str(num or "").strip().split(".") if p]
if not parts or not all(p.isdigit() for p in parts):
return False
if any(_segment_looks_like_year(p) for p in parts):
return False
if len(parts) == 1:
return 1 <= int(parts[0]) <= 20
return all(1 <= int(p) <= 99 for p in parts)
def _heading_title_core(rest: str) -> str:
return re.sub(r"^\d+(?:\.\d+)*\s*", "", str(rest or "").strip()).strip()
def _rest_looks_like_body_text(rest: str) -> bool:
"""过滤日期句、长段落、数据说明句等被误识别为标题的正文。"""
t = _heading_title_core(rest) or str(rest or "").strip()
if not t:
return True
if re.match(r"^[月日]", t):
return True
if re.search(r"\d", t):
return True
if re.match(r"^\d{4}\s*年", t) or re.match(r"^\d{4}[、,]", t):
return True
if re.search(r"\d{4}\s*[-~—至]\s*\d{4}", t):
return True
if t.count("") >= 2 or t.count("") >= 2:
return True
if len(t) > 80 and re.search(r"[,。;:]", t):
return True
if len(t) > 45 and any(
k in t
for k in (
"运营数据", "预测数据", "实际运营", "根据公司",
"发展规划", "工况下", "万吨", "有项目", "无项目",
)
):
return True
if len(t) > 45 and not re.search(
r"(评价|分析|结论|概况|说明|措施|建议|对比|控制|实现|状况|情况|程序|模式|评价结论)$",
t.rstrip("。;,"),
):
return True
return False
def _looks_like_real_heading_title(title: str) -> bool:
if not str(title or "").strip():
return False
return not _rest_looks_like_body_text(title)
def _clean_heading_title(s: str) -> str:
t = str(s or "").strip()
t = re.sub(r"\s+", " ", t)
t = re.sub(r"\s+\d+$", "", t).strip() # 去掉目录行尾页码
m_note = re.search(r"[(]([^)]{20,})[)]", t)
if m_note and re.search(r"[,。;:]", m_note.group(1)):
t = re.sub(r"\s*[(][^)]{20,}[)]\s*$", "", t).strip()
return t
def _section_dict(section_key: str, section_title: str) -> dict:
return {"sectionKey": section_key, "sectionTitle": section_title}
def _canonical_to_section_key(canonical: str, order: int) -> str:
return (
re.sub(r"[^a-z0-9\u4e00-\u9fa5]+", "-", canonical).strip("-")
or f"section-{order}"
)
def normalize_section_key(raw_key: str | None, title: str | None) -> str:
"""生成稳定且可入库的 section_key<=64超长追加短哈希。"""
base = (raw_key or "").strip().lower()
if not base:
base = (title or "").strip().lower()
base = re.sub(r"[^a-z0-9\u4e00-\u9fa5]+", "-", base).strip("-")
if not base:
base = "section"
if len(base) <= 64:
return base
digest = hashlib.md5(base.encode("utf-8")).hexdigest()[:10]
prefix = base[:53].rstrip("-")
return f"{prefix}-{digest}"
# ────────────────────────────── 目录TOC抽取 ──────────────────────────────
def _walk_markdown_heading_sections(text: str) -> list[dict]:
"""
单次遍历 Markdown按标题# ~ ######)切分章节并捕获正文(不含本节标题行)。
标题层级自动编号## 项目概况 -> 1.1 项目概况),无显式编号也可处理。
被判定为"非真实标题" # 行视为正文内容,不另起章节。
正文范围
- 默认SECTION_CONTENT_INCLUDE_SUBSECTIONS=True聚合整棵子树
即本节标题之后直到下一个"层级 <= 本节"的标题之前的全部内容
含下级小节标题与正文保证父章节正文非空
- 关闭时仅取到下一个任意标题之前本节自身正文
返回每节{number, title, full_title, canonical, section_key(canonical), level, content}
目录抽取与正文拆分共用此函数确保目录与内容一一对应
"""
from config import settings
include_sub = bool(getattr(settings, "SECTION_CONTENT_INCLUDE_SUBSECTIONS", True))
lines = str(text or "").splitlines()
counters: list[int] = []
accepted: list[dict] = []
seen: set[str] = set()
for idx, raw in enumerate(lines):
m = re.match(r"^(#{1,6})\s+(.+)$", str(raw or "").strip())
if not m:
continue
level = len(m.group(1))
title = _clean_heading_title(m.group(2).strip())
is_valid = (
bool(title)
and len(title) <= _MAX_SECTION_TITLE_LEN
and _looks_like_real_heading_title(title)
)
if not is_valid:
continue
if len(counters) < level:
counters.extend([0] * (level - len(counters)))
else:
counters = counters[:level]
counters[level - 1] += 1
for i in range(level, len(counters)):
counters[i] = 0
num = ".".join(str(counters[i]) for i in range(level))
full_title = f"{num} {title}"
canonical = f"{num}|{title}".lower()
if canonical in seen:
continue
seen.add(canonical)
accepted.append(
{
"number": num,
"title": title,
"full_title": full_title,
"canonical": canonical,
"section_key": _canonical_to_section_key(canonical, len(accepted) + 1),
"level": level,
"start_idx": idx,
}
)
total = len(lines)
for i, sec in enumerate(accepted):
body_start = sec["start_idx"] + 1 # 排除本节标题行
end = total
for j in range(i + 1, len(accepted)):
nxt = accepted[j]
if include_sub:
if nxt["level"] <= sec["level"]:
end = nxt["start_idx"]
break
else:
end = nxt["start_idx"]
break
sec["content"] = "\n".join(lines[body_start:end]).strip()
sec.pop("start_idx", None)
return accepted
def _extract_sections_from_markdown_headings(text: str) -> list[dict]:
"""
Markdown 标题# / ## / ###)构建模板章节目录。
复刻 eval_report 报告模板管理模块 services/template_service.py 的同名逻辑
标题层级自动编号## 项目概况 -> 1.1 项目概况),并过滤非真实标题行。
"""
lines = str(text or "").splitlines()
counters: list[int] = []
out: list[dict] = []
seen: set[str] = set()
for raw in lines:
m = re.match(r"^(#{1,6})\s+(.+)$", str(raw or "").strip())
if not m:
continue
level = len(m.group(1))
title = _clean_heading_title(m.group(2).strip())
if not title or len(title) > _MAX_SECTION_TITLE_LEN:
continue
if not _looks_like_real_heading_title(title):
continue
if len(counters) < level:
counters.extend([0] * (level - len(counters)))
else:
counters = counters[:level]
counters[level - 1] += 1
for i in range(level, len(counters)):
counters[i] = 0
num = ".".join(str(counters[i]) for i in range(level))
full_title = f"{num} {title}"
canonical = f"{num}|{title}".lower()
if canonical in seen:
continue
seen.add(canonical)
out.append(
_section_dict(
_canonical_to_section_key(canonical, len(out) + 1),
full_title,
)
)
return out
def extract_sections_from_text(text: str) -> list[dict]:
"""抽取模板章节目录(入库 report_template_sections
复刻 eval_report 报告模板管理模块的逻辑优先按 Markdown 标题层级识别
命中数 >= 8 时直接采用否则回退到目录/编号行识别"""
md_sections = _extract_sections_from_markdown_headings(text)
if len(md_sections) >= 8:
return md_sections
lines = str(text or "").splitlines()
out: list[dict] = []
seen: set[str] = set()
candidates: list[dict] = []
for raw in lines:
line = str(raw or "").strip()
if not line:
continue
line = re.sub(r"^#{1,6}\s*", "", line).strip()
line = line.replace("\u3000", " ")
line = re.sub(r"\s+", " ", line).strip()
if re.match(r"^20\d{2}\s*年\s*\d{1,2}\s*月$", line):
continue
if line in {"目次", "目录"}:
continue
if re.match(r"^\d+\s*[\)]\s*.+$", line):
continue
has_page_no = bool(re.search(r"\s+\d+\s*$", line))
m = re.match(r"^((?:\d+(?:\.\d+){0,5}))\s*([^\s].*)$", line)
if m:
num = m.group(1).strip()
if not _is_valid_section_number(num):
continue
rest = _clean_heading_title(m.group(2).strip())
if not rest or rest.startswith("") or rest.startswith(")"):
continue
if _rest_looks_like_body_text(rest):
continue
if len(rest) > _MAX_SECTION_TITLE_LEN:
continue
full_title = f"{num} {rest}"[:_MAX_SECTION_TITLE_LEN].rstrip()
canonical = f"{num}|{rest}".lower()
else:
m2 = re.match(r"^([一二三四五六七八九十]+[、.])\s*([^\s].*)$", line)
if not m2:
continue
rest2 = _clean_heading_title(m2.group(2).strip())
if not rest2 or _rest_looks_like_body_text(rest2) or len(rest2) > _MAX_SECTION_TITLE_LEN:
continue
full_title = f"{m2.group(1)} {rest2}"[:_MAX_SECTION_TITLE_LEN].rstrip()
canonical = f"{m2.group(1)}|{rest2}".lower()
candidates.append({"canonical": canonical, "title": full_title, "has_page_no": has_page_no})
use_toc_only = False
toc_rows = [c for c in candidates if c["has_page_no"]]
toc_nums = set()
for c in toc_rows:
m_num = re.match(r"^(\d+)", c["title"])
if m_num:
toc_nums.add(m_num.group(1))
if len(toc_rows) >= 20 and {"1", "2", "3", "4", "5", "6", "7"}.issubset(toc_nums):
use_toc_only = True
picked = toc_rows if use_toc_only else candidates
for c in picked:
canonical = c["canonical"]
if canonical in seen:
continue
if not _looks_like_real_heading_title(c["title"]):
continue
seen.add(canonical)
out.append(_section_dict(_canonical_to_section_key(canonical, len(out) + 1), c["title"]))
return out
# ────────────────────────────── 正文按标题拆分 ──────────────────────────────
def split_markdown_into_sections(text: str) -> list[dict[str, str]]:
"""
Markdown 标题切分正文与目录抽取共用同一套标题识别与自动编号
保证每个目录章节都能拿到对应正文section_key 为自动编号 1.1
若文档不含 Markdown 标题则回退到"带编号标题"的拆分方式
返回 [{section_key, section_title, content}, ...]
"""
walk = _walk_markdown_heading_sections(text)
if walk:
return [
{
"section_key": s["number"],
"section_title": s["full_title"],
"content": s["content"],
}
for s in walk
]
return split_markdown_by_headings(text)
def split_markdown_by_headings(text: str) -> list[dict[str, str]]:
"""
Markdown 标题# ~ ####,且带章节编号,如 ## 1.1 标题)拆分正文。
返回 [{section_key, section_title, content}, ...]section_key 为编号 1.1
"""
lines = str(text or "").splitlines()
heading_pattern = re.compile(r"^#{1,4}\s+(\d+(?:\.\d+)*)\s+(.+)")
sections: list[dict[str, str]] = []
current_key: str | None = None
current_title: str | None = None
current_lines: list[str] = []
for line in lines:
m = heading_pattern.match(line)
if m:
if current_key and current_lines:
sections.append({
"section_key": current_key,
"section_title": current_title or "",
"content": "\n".join(current_lines).strip(),
})
current_key = m.group(1)
current_title = m.group(2).strip()
current_lines = [line]
else:
if current_key:
current_lines.append(line)
if current_key and current_lines:
sections.append({
"section_key": current_key,
"section_title": current_title or "",
"content": "\n".join(current_lines).strip(),
})
return sections
def parse_section_order(section_key: str) -> int:
"""'1.2.1' 转为整数 121 用于排序。"""
digits = str(section_key or "").replace(".", "")
return int(digits) if digits.isdigit() else 0
def clamp_text_bytes(text: str, max_bytes: int, *, suffix: str = "\n…(内容过长,已截断)") -> str:
"""
将文本按 UTF-8 字节数截断到 max_bytes 以内且不会截断到半个字符
用于适配 MySQL TEXT 最大 65535 字节
"""
if not text or max_bytes <= 0:
return text
data = text.encode("utf-8")
if len(data) <= max_bytes:
return text
suffix_bytes = len(suffix.encode("utf-8"))
budget = max(max_bytes - suffix_bytes, 0)
# errors="ignore" 会丢弃末尾被切断的不完整字符,保证是合法 UTF-8
truncated = data[:budget].decode("utf-8", errors="ignore").rstrip()
return truncated + suffix

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,524 @@
"""
services/template_service.py
复刻自 eval_reportreport_template_sections 数据的获取方式
- DEFAULT_TEMPLATE_SECTIONS系统默认后评价报告章节目录key, title
- default_section_prompt / default_section_output_contract / default_section_examples
按章节标题/编号取对应提示词输出合同示例
- build_default_template_catalog默认目录 + 提示词/合同供上传模版匹配
说明eval_report 会额外从编制细则模版Word 文档抽取更细的提示词/示例
本项目默认不含这两个 .doc 文件与 DocParser故相关函数在缺文件时优雅降级
回退到 SECTION_PROMPT_RULES / SECTION_EXAMPLE_RULES
"""
from __future__ import annotations
import re
import uuid
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from sqlalchemy.orm import Session
from database.models import ReportTemplate, ReportTemplateSection
from prompts.report_generation.section_output_contracts import (
DEFAULT_SECTION_OUTPUT_CONTRACT,
SECTION_OUTPUT_CONTRACTS,
)
from prompts.report_generation.template_prompt_rules import (
DEFAULT_SECTION_PROMPT,
SECTION_EXAMPLE_RULES,
SECTION_PROMPT_RULES,
)
SYSTEM_DEFAULT_TEMPLATE_NAME = "后评价默认模板"
GUIDELINE_BASENAME = "炼油化工建设项目后评价报告编制细则(修订)"
PROJECT_EXAMPLE_BASENAME = "模版"
MAX_SECTION_EXAMPLE_CHARS = 12000
DEFAULT_TEMPLATE_SECTIONS: list[tuple[str, str]] = [
("1", "1 项目概况"),
("1-1", "1.1 项目基本情况"),
("1-2", "1.2 项目决策要点"),
("1-3", "1.3 项目实施情况"),
("1-4", "1.4 项目运行情况"),
("2", "2 前期工作评价"),
("2-1", "2.1 项目要素评价"),
("2-1-1", "2.1.1 资源与原料评价"),
("2-1-2", "2.1.2 产品方案及市场评价"),
("2-1-2-1", "2.1.2.1 产品方案评价"),
("2-1-2-2", "2.1.2.2 产品市场评价"),
("2-1-3", "2.1.3 工艺方案评价"),
("2-1-3-1", "2.1.3.1 总加工方案评价"),
("2-1-3-2", "2.1.3.2 建设规模及工艺技术方案评价"),
("2-1-3-3", "2.1.3.3 主要设备方案评价"),
("2-1-4", "2.1.4 厂址选择及外部条件评价"),
("2-1-5", "2.1.5 总图及系统配套工程评价"),
("2-1-6", "2.1.6 主要技术指标评价"),
("2-1-7", "2.1.7 风险分析评价"),
("2-2", "2.2 工作程序评价"),
("2-2-1", "2.2.1 编制单位资质及选择方式评价"),
("2-2-2", "2.2.2 编制进度评价"),
("2-2-3", "2.2.3 与专项评价的结合情况"),
("2-2-4", "2.2.4 可行性研究报告的质量评价"),
("2-3", "2.3 前评估工作评价"),
("2-4", "2.4 初步设计评价"),
("2-4-1", "2.4.1 设计单位资质及选择方式评价"),
("2-4-2", "2.4.2 初步设计进度评价"),
("2-4-3", "2.4.3 初步设计质量评价"),
("2-4-4", "2.4.4 初步设计审查工作评价"),
("2-5", "2.5 前期决策程序评价"),
("2-6", "2.6 前期工作评价结论"),
("3", "3 建设实施评价"),
("3-1", "3.1 工程建设管理模式评价"),
("3-2", "3.2 招投标评价"),
("3-3", "3.3 施工图设计评价"),
("3-3-1", "3.3.1 与批复后初步设计符合性评价"),
("3-3-2", "3.3.2 设计进度评价"),
("3-3-3", "3.3.3 施工图设计水平及质量评价"),
("3-3-4", "3.3.4 施工图设计变更管理评价"),
("3-4", "3.4 工程承包商或施工单位评价"),
("3-4-1", "3.4.1 施工准备评价"),
("3-4-2", "3.4.2 施工计划的执行情况"),
("3-5", "3.5 采购工作评价"),
("3-6", "3.6 工程监理评价"),
("3-7", "3.7 工程质量评价"),
("3-8", "3.8 HSE管理评价"),
("3-9", "3.9 三查四定及中间交接"),
("3-10", "3.10 工程竣工验收评价"),
("3-11", "3.11 建设实施评价结论"),
("4", "4 生产运行评价"),
("4-1", "4.1 生产准备评价"),
("4-2", "4.2 联合试运与试生产情况评价"),
("4-3", "4.3 生产运行评价"),
("4-3-1", "4.3.1 原料供应评价"),
("4-3-2", "4.3.2 生产运行总体情况评价"),
("4-3-3", "4.3.3 达标评价"),
("4-3-4", "4.3.4 生产工艺技术评价"),
("4-3-5", "4.3.5 设备运行评价"),
("4-3-6", "4.3.6 公用工程及辅助设施合理性评价"),
("4-4", "4.4 生产运行评价结论"),
("5", "5 投资与经济效益评价"),
("5-1", "5.1 主要经济指标实现程度评价"),
("5-2", "5.2 投资和执行情况评价"),
("5-2-1", "5.2.1 投资控制及变动原因分析"),
("5-2-2", "5.2.2 投资水平分析"),
("5-2-3", "5.2.3 资金来源及到位评价"),
("5-2-4", "5.2.4 投资控制的经验和教训"),
("5-3", "5.3 经济效益分析"),
("5-3-1", "5.3.1 项目投产以来生产经营及效益状况"),
("5-3-2", "5.3.2 项目经济效益后评价"),
("5-4", "5.4 不确定性分析"),
("5-5", "5.5 投资与经济效益评价结论"),
("6", "6 影响与持续性评价"),
("6-1", "6.1 影响评价"),
("6-1-1", "6.1.1 环境影响评价"),
("6-1-2", "6.1.2 安全影响评价"),
("6-1-3", "6.1.3 科技进步影响"),
("6-1-4", "6.1.4 项目社会影响评价"),
("6-1-5", "6.1.5 项目影响评价结论"),
("6-2", "6.2 持续性评价"),
("6-2-1", "6.2.1 资源分析"),
("6-2-2", "6.2.2 产品分析"),
("6-2-3", "6.2.3 主要技术及经济指标对比"),
("6-2-4", "6.2.4 项目持续性评价结论"),
("7", "7 综合评价结论"),
("7-1", "7.1 综合评价结论"),
("7-1-1", "7.1.1 总体评价结论"),
("7-1-2", "7.1.2 成功度评价"),
("7-2", "7.2 主要经验"),
("7-3", "7.3 问题与建议"),
]
def default_section_output_contract(section_title: str, section_key: str | None = None) -> str:
section_no = _extract_number_prefix(section_title) or _section_key_to_number(section_key)
if section_no and section_no in SECTION_OUTPUT_CONTRACTS:
return SECTION_OUTPUT_CONTRACTS[section_no]
return DEFAULT_SECTION_OUTPUT_CONTRACT
def default_section_prompt(section_title: str, section_key: str | None = None) -> str:
guideline_prompt = _guideline_prompt_for(section_title, section_key)
if guideline_prompt:
return guideline_prompt
title = _normalize_section_identity(section_title)
key = str(section_key or "").strip().lower()
for pattern, prompt in SECTION_PROMPT_RULES:
p = pattern.lower()
if title.startswith(p):
return prompt
if p.isdigit() and (title.startswith(f"{p} ") or key.startswith(f"{p}-") or key == p):
return prompt
return DEFAULT_SECTION_PROMPT
def build_default_template_catalog() -> list[dict[str, str]]:
"""系统默认模板章节目录及对应提示词、输出合同(供上传模版匹配)。"""
out: list[dict[str, str]] = []
for key, title in DEFAULT_TEMPLATE_SECTIONS:
out.append(
{
"sectionKey": key,
"sectionTitle": title,
"sectionNumber": _extract_number_prefix(title) or _section_key_to_number(key),
"sectionPrompt": default_section_prompt(title, key),
"sectionOutputContract": default_section_output_contract(title, key),
}
)
return out
def default_section_examples(section_title: str, section_key: str | None = None) -> str:
project_example = _project_example_for(section_title, section_key)
if project_example:
return project_example
title = _normalize_section_identity(section_title)
key = str(section_key or "").strip().lower()
num = _extract_number_prefix(section_title) or _section_key_to_number(section_key)
chapter_no = ""
if num:
chapter_no = num.split(".")[0]
elif key:
chapter_no = key.split("-")[0]
for prefix, examples in SECTION_EXAMPLE_RULES:
p = str(prefix).strip().lower()
if chapter_no == p:
return examples
if title.startswith(f"{p} "):
return examples
if key.startswith(f"{p}-") or key == p:
return examples
return ""
def _normalize_section_identity(value: str | None) -> str:
text = str(value or "").strip().lower()
text = text.replace("", ".").replace("", ".")
text = re.sub(r"\s+", " ", text)
return text
def _section_key_to_number(section_key: str | None) -> str:
key = str(section_key or "").strip()
if not key:
return ""
if re.fullmatch(r"\d+(?:-\d+)*", key):
return key.replace("-", ".")
return ""
def _extract_number_prefix(title: str) -> str:
m = re.match(r"^\s*(\d+(?:\.\d+)*)\s*", str(title or ""))
return m.group(1) if m else ""
def _normalize_heading_key(value: str) -> str:
s = str(value or "").strip().lower()
s = s.replace("", ".").replace("", ".")
s = re.sub(r"\s+", "", s)
return s
def _tuple_from_number(number_str: str) -> tuple[int, ...]:
if not number_str:
return tuple()
parts = []
for p in number_str.split("."):
if p.isdigit():
parts.append(int(p))
else:
return tuple()
return tuple(parts)
def _read_doc_text(path: str) -> str:
"""读取 .doc/.docx 文本。本项目无 DocParser 时返回空串(优雅降级)。"""
try:
from function.documents.doc_parser import DocParser # type: ignore
except Exception:
return ""
try:
return DocParser(path).read()
except Exception:
return ""
@lru_cache(maxsize=1)
def _guideline_section_prompt_map() -> dict[str, str]:
guideline_path = _resolve_guideline_path()
if not guideline_path:
return {}
raw_text = _read_doc_text(guideline_path)
if not raw_text:
return {}
return _build_guideline_prompt_map(raw_text)
def _resolve_guideline_path() -> str | None:
root = Path(__file__).resolve().parents[1]
candidates = [
root / f"{GUIDELINE_BASENAME}.doc",
root / f"{GUIDELINE_BASENAME}.docx",
]
for p in candidates:
if p.is_file():
return str(p)
return None
def _resolve_project_example_path() -> str | None:
root = Path(__file__).resolve().parents[1]
candidates = [
root / f"{PROJECT_EXAMPLE_BASENAME}.doc",
root / f"{PROJECT_EXAMPLE_BASENAME}.docx",
]
for p in candidates:
if p.is_file():
return str(p)
return None
@lru_cache(maxsize=1)
def _project_example_entries() -> list[tuple[str, str]]:
path = _resolve_project_example_path()
if not path:
return []
raw_text = _read_doc_text(path)
if not raw_text:
return []
return _build_project_example_entries(raw_text)
def _build_project_example_entries(text: str) -> list[tuple[str, str]]:
lines = str(text or "").splitlines()
headings: list[tuple[int, int, str]] = []
for idx, raw in enumerate(lines):
line = str(raw or "").strip()
m = re.match(r"^\s*(#{1,6})\s*(.+?)\s*$", line)
if not m:
continue
level = len(m.group(1))
heading_title = m.group(2).strip()
if not heading_title:
continue
headings.append((idx, level, heading_title))
out: list[tuple[str, str]] = []
for i, (start_idx, level, title) in enumerate(headings):
end_idx = len(lines)
for j in range(i + 1, len(headings)):
next_idx, next_level, _ = headings[j]
if next_level <= level:
end_idx = next_idx
break
body = "\n".join(lines[start_idx + 1 : end_idx]).strip()
body = re.sub(r"\n{3,}", "\n\n", body)
if not body:
continue
out.append((title, body))
return out
def _project_example_for(section_title: str, section_key: str | None = None) -> str:
entries = _project_example_entries()
if not entries:
return ""
target_title = _clean_section_title(section_title)
target_key = _section_key_to_number(section_key)
target_core = _core_title(target_title or target_key)
if not target_core:
return ""
best_title = ""
best_body = ""
best_score = -1
for heading, body in entries:
heading_clean = _clean_section_title(heading)
heading_core = _core_title(heading_clean)
score = _title_match_score(target_core, heading_core)
if score > best_score:
best_score = score
best_title = heading_clean
best_body = body
if best_score < 4 or not best_body:
return ""
text = f"### {best_title}\n\n{best_body}".strip()
if len(text) > MAX_SECTION_EXAMPLE_CHARS:
text = text[:MAX_SECTION_EXAMPLE_CHARS].rstrip() + "\n\n(示例过长,已截断)"
return text
def _clean_section_title(value: str | None) -> str:
s = str(value or "").strip()
s = re.sub(r"^\s*\d+(?:[.\-]\d+)*\s*", "", s)
return s.strip()
def _core_title(value: str | None) -> str:
s = str(value or "").strip()
s = s.replace("", "(").replace("", ")")
s = re.sub(r"\([^)]*\)", "", s)
s = re.sub(r"[、,。;::()\-\s]", "", s)
s = s.replace("项目", "")
s = s.replace("情况", "")
s = s.replace("工作", "")
return s.strip().lower()
def _title_match_score(target: str, candidate: str) -> int:
if not target or not candidate:
return 0
if target == candidate:
return 100
score = 0
if target in candidate or candidate in target:
score += 40
tks_t = re.findall(r"[\u4e00-\u9fa5]{2,8}|[a-z]{2,12}", target)
tks_c = re.findall(r"[\u4e00-\u9fa5]{2,8}|[a-z]{2,12}", candidate)
if tks_t and tks_c:
overlap = len(set(tks_t) & set(tks_c))
score += overlap * 8
ch_overlap = len(set(target) & set(candidate))
score += min(ch_overlap, 20)
return score
def _build_guideline_prompt_map(text: str) -> dict[str, str]:
lines = str(text or "").splitlines()
headings: list[tuple[int, str, str, tuple[int, ...]]] = []
for idx, raw in enumerate(lines):
line = str(raw or "").strip()
m = re.match(r"^\s*#{1,6}\s*(.+?)\s*$", line)
if not m:
continue
heading_title = m.group(1).strip()
number = _extract_number_prefix(heading_title)
number_tuple = _tuple_from_number(number)
if not number_tuple:
continue
headings.append((idx, heading_title, number, number_tuple))
prompt_map: dict[str, str] = {}
for i, (start_idx, heading_title, number, number_tuple) in enumerate(headings):
end_idx = len(lines)
for j in range(i + 1, len(headings)):
next_start, _, _, next_tuple = headings[j]
if len(next_tuple) < len(number_tuple) or next_tuple[: len(number_tuple)] != number_tuple:
end_idx = next_start
break
body = "\n".join(lines[start_idx + 1 : end_idx]).strip()
body = re.sub(r"\n{3,}", "\n\n", body)
if not body:
continue
key_title = _normalize_heading_key(heading_title)
key_number = _normalize_heading_key(number)
prompt_map[key_title] = body
prompt_map[key_number] = body
return prompt_map
def _guideline_prompt_for(section_title: str, section_key: str | None = None) -> str:
mapping = _guideline_section_prompt_map()
if not mapping:
return ""
title = str(section_title or "").strip()
number = _extract_number_prefix(title) or _section_key_to_number(section_key)
candidates = [
_normalize_heading_key(title),
_normalize_heading_key(number),
]
for key in candidates:
if key and key in mapping:
return mapping[key]
return ""
def list_templates(db: Session) -> list[ReportTemplate]:
return (
db.query(ReportTemplate)
.order_by(ReportTemplate.is_default.desc(), ReportTemplate.updated_at.desc())
.all()
)
def ensure_default_template(db: Session) -> None:
now = datetime.now()
system_default = (
db.query(ReportTemplate)
.filter(ReportTemplate.name == SYSTEM_DEFAULT_TEMPLATE_NAME)
.first()
)
if not system_default:
system_default = ReportTemplate(
id=uuid.uuid4().hex,
name=SYSTEM_DEFAULT_TEMPLATE_NAME,
description="系统预置模板(细则完整章节)",
is_default=True,
is_active=True,
created_at=now,
updated_at=now,
)
db.add(system_default)
db.flush()
current_rows = (
db.query(ReportTemplateSection)
.filter(ReportTemplateSection.template_id == system_default.id)
.order_by(ReportTemplateSection.section_order.asc())
.all()
)
current_pairs = [(r.section_key, r.section_title) for r in current_rows]
expected_pairs = list(DEFAULT_TEMPLATE_SECTIONS)
db.query(ReportTemplate).update({ReportTemplate.is_default: False})
system_default.is_default = True
system_default.is_active = True
system_default.updated_at = now
if current_pairs == expected_pairs:
has_changed = False
for row in current_rows:
current_examples = str(row.examples or "").strip()
new_examples = default_section_examples(row.section_title, row.section_key).strip()
if new_examples and current_examples != new_examples:
row.examples = new_examples
row.updated_at = now
has_changed = True
current_out = str(getattr(row, "section_output_contract", None) or "").strip()
new_out = default_section_output_contract(row.section_title, row.section_key).strip()
if not current_out and new_out:
row.section_output_contract = new_out
row.updated_at = now
has_changed = True
if has_changed:
system_default.updated_at = now
db.commit()
return
db.query(ReportTemplateSection).filter(
ReportTemplateSection.template_id == system_default.id
).delete()
for i, (key, title) in enumerate(DEFAULT_TEMPLATE_SECTIONS):
db.add(
ReportTemplateSection(
id=uuid.uuid4().hex,
template_id=system_default.id,
section_key=key,
section_title=title,
section_prompt="",
section_output_contract=default_section_output_contract(title, key),
section_order=i,
examples=default_section_examples(title, key),
created_at=now,
updated_at=now,
)
)
db.commit()