347 lines
13 KiB
Python
347 lines
13 KiB
Python
"""
routers/template.py

报告模板管理:上传文档 → 远程解析为 Markdown → 抽取目录并为每个目录生成声明
→ 创建模板(目录 + 声明)→ 按章节拆分正文并入库远程 MySQL。
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import re
|
||
import tempfile
|
||
import uuid
|
||
from datetime import datetime
|
||
|
||
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
|
||
from sqlalchemy import or_
|
||
from sqlalchemy.orm import Session
|
||
|
||
from database import get_db
|
||
from database.models import (
|
||
ReportSectionReference,
|
||
ReportTemplate,
|
||
ReportTemplateSection,
|
||
)
|
||
from schemas.template import (
|
||
SectionReferenceItem,
|
||
TemplateItem,
|
||
TemplateSectionItem,
|
||
UploadTemplateResult,
|
||
)
|
||
from services.template_prompt_mapper import resolve_uploaded_template_prompts
|
||
from services.desensitize_service import count_masked_numbers, desensitize_content
|
||
from services.file_parse_client import parse_file_to_markdown
|
||
from services.section_extractor import (
|
||
clamp_text_bytes,
|
||
extract_sections_from_text,
|
||
normalize_section_key,
|
||
parse_section_order,
|
||
split_markdown_into_sections,
|
||
)
|
||
|
||
from config import settings
|
||
|
||
# Module-level logger for this router.
logger = logging.getLogger(__name__)

# Every endpoint in this module is mounted under /templates and shares one OpenAPI tag.
router = APIRouter(prefix="/templates", tags=["报告模板管理"])

# File extensions accepted by the upload endpoint (the upload route lower-cases
# the incoming suffix before checking membership).
ALLOWED_SUFFIXES = {
    f".{ext}" for ext in ("doc", "docx", "pdf", "txt", "md", "html", "htm", "rtf")
}
|
||
|
||
|
||
def _clamp_title(value: str | None) -> str:
|
||
return str(value or "").strip()[:255]
|
||
|
||
|
||
def _build_description(source_file: str | None) -> str:
|
||
sf = str(source_file or "").strip()
|
||
base = "通过文件上传导入"
|
||
return f"{base}\n来源文件:{sf}" if sf else base
|
||
|
||
|
||
def _find_duplicates(db: Session, filename: str) -> dict:
    """Find data already imported from ``filename`` (used for duplicate checks).

    Templates are matched through the ``来源文件:<name>`` entry that the upload
    writes into their description. Section references are matched on their
    ``source_file`` column.

    Args:
        db: Active SQLAlchemy session.
        filename: Exact (already stripped) upload file name.

    Returns:
        ``{"template_ids": [...], "ref_count": int}``.
    """
    # SQL prefilter: an escaped substring match. ``autoescape=True`` keeps ``%``
    # and ``_`` in file names from acting as LIKE wildcards. A substring match
    # alone is still too loose (``a.doc`` would hit ``a.docx``), so each
    # candidate is confirmed by parsing the exact source name back out of its
    # description, the same way the delete endpoint reads it.
    candidates = (
        db.query(ReportTemplate.id, ReportTemplate.description)
        .filter(
            ReportTemplate.description.contains(f"来源文件:{filename}", autoescape=True)
        )
        .all()
    )
    template_ids = [
        row.id
        for row in candidates
        if _extract_source_file(row.description) == filename
    ]
    ref_count = (
        db.query(ReportSectionReference)
        .filter(ReportSectionReference.source_file == filename)
        .count()
    )
    return {"template_ids": template_ids, "ref_count": ref_count}
|
||
|
||
|
||
def _delete_by_source(db: Session, filename: str, template_ids: list[str]) -> None:
    """Delete everything previously imported from ``filename``, then commit.

    Removes the given templates and their sections. Also removes every section
    reference that either records ``filename`` as its source or belongs to one
    of the given templates, so no orphaned references are left behind.

    Args:
        db: Active SQLAlchemy session.
        filename: Source file name whose imported data should be removed.
        template_ids: Ids of templates previously imported from ``filename``.

    Raises:
        Any database error, after rolling the session back.
    """
    ref_conditions = [ReportSectionReference.source_file == filename]
    if template_ids:
        ref_conditions.append(ReportSectionReference.template_id.in_(template_ids))
    try:
        # Child rows go first, so the deletes also succeed if the schema
        # enforces foreign keys to report templates without ON DELETE CASCADE.
        db.query(ReportSectionReference).filter(or_(*ref_conditions)).delete(
            synchronize_session=False
        )
        if template_ids:
            db.query(ReportTemplateSection).filter(
                ReportTemplateSection.template_id.in_(template_ids)
            ).delete(synchronize_session=False)
            db.query(ReportTemplate).filter(
                ReportTemplate.id.in_(template_ids)
            ).delete(synchronize_session=False)
        db.commit()
    except Exception:
        # Leave the session usable for the caller instead of half-deleted.
        db.rollback()
        raise
|
||
|
||
|
||
def _extract_source_file(description: str | None) -> str | None:
|
||
"""从模板描述中解析来源文件名(上传时写入 '来源文件:xxx')。"""
|
||
m = re.search(r"来源文件\s*[::]\s*(.+)$", str(description or ""))
|
||
if not m:
|
||
return None
|
||
return (m.group(1) or "").strip() or None
|
||
|
||
|
||
def _serialize_template(db: Session, template_id: str) -> TemplateItem:
    """Load one template plus its sections (ascending ``section_order``) as ``TemplateItem``.

    Timestamps are rendered as ``YYYY-MM-DD HH:MM:SS`` strings, or ``None``
    when unset. ``sourceFile`` is parsed out of the template description.

    Raises:
        HTTPException: 404 when no template has ``template_id``.
    """
    template = (
        db.query(ReportTemplate).filter(ReportTemplate.id == template_id).first()
    )
    if not template:
        raise HTTPException(status_code=404, detail="模板不存在")

    def _fmt(moment):
        return moment.strftime("%Y-%m-%d %H:%M:%S") if moment else None

    section_rows = (
        db.query(ReportTemplateSection)
        .filter(ReportTemplateSection.template_id == template.id)
        .order_by(ReportTemplateSection.section_order.asc())
        .all()
    )
    section_items = [
        TemplateSectionItem(
            id=row.id,
            sectionKey=row.section_key,
            sectionTitle=row.section_title,
            sectionPrompt=row.section_prompt,
            sectionOutputContract=row.section_output_contract,
            sectionOrder=row.section_order,
            examples=row.examples,
        )
        for row in section_rows
    ]
    return TemplateItem(
        id=template.id,
        name=template.name,
        description=template.description,
        sourceFile=_extract_source_file(template.description),
        createdAt=_fmt(template.created_at),
        updatedAt=_fmt(template.updated_at),
        isDefault=template.is_default,
        isActive=template.is_active,
        sections=section_items,
    )
|
||
|
||
|
||
def _build_reference_sections(markdown: str) -> tuple[list[dict], dict[str, int]]:
    """Split ``markdown`` by headings and prepare each section body for storage.

    Each body is passed through ``desensitize_content``. Sections whose
    desensitized body is blank are dropped. The rest are clamped with
    ``clamp_text_bytes`` to ``settings.SECTION_CONTENT_MAX_BYTES`` (60000 when
    unset or falsy).

    Returns:
        ``(ref_sections, stats)``. ``ref_sections`` holds the split section dicts
        with ``content`` replaced by the processed text. ``stats`` holds the
        ``skipped_empty``, ``truncated`` and ``masked_total`` counters used for
        logging.
    """
    max_bytes = int(getattr(settings, "SECTION_CONTENT_MAX_BYTES", 60000) or 60000)
    ref_sections: list[dict] = []
    stats = {"skipped_empty": 0, "truncated": 0, "masked_total": 0}
    for sec in split_markdown_into_sections(markdown):
        filtered = desensitize_content(sec["content"])
        if not filtered.strip():
            stats["skipped_empty"] += 1
            continue
        stats["masked_total"] += max(count_masked_numbers(sec["content"], filtered), 0)
        clamped = clamp_text_bytes(filtered, max_bytes)
        # Count the section as truncated when clamping changed its length.
        if len(clamped) != len(filtered):
            stats["truncated"] += 1
        ref_sections.append({**sec, "content": clamped})
    return ref_sections, stats


# Upload pipeline: validate suffix -> duplicate check -> remote parse to Markdown
# -> extract TOC -> split/desensitize bodies -> resolve per-section prompts ->
# persist template, sections and references.
# Handler notes are kept as comments (not a docstring) because FastAPI would
# publish a docstring as the endpoint description in OpenAPI.
# Errors: 400 for unsupported formats, parse failures, empty results, missing
# headings or other creation errors; 409 when the file was already imported
# and force is false.
@router.post("/upload", response_model=UploadTemplateResult, summary="上传文档并解析为模板(目录+声明)与章节内容")
async def upload_template_route(
    file: UploadFile = File(...),
    force: bool = Query(False, description="为 true 时覆盖同名来源文件的已有模板与章节后重新导入"),
    db: Session = Depends(get_db),
):
    filename = (file.filename or "").strip()
    suffix = os.path.splitext(filename)[1].lower()
    if suffix not in ALLOWED_SUFFIXES:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件格式({suffix or '未知'});支持:{', '.join(sorted(ALLOWED_SUFFIXES))}",
        )

    # 0) Duplicate check: reject re-imports unless force=true. With force, the
    #    old data is deleted only after parsing and prompt resolution succeed
    #    (step 6). A failure in those slow steps leaves existing data untouched.
    dup = _find_duplicates(db, filename)
    has_dup = bool(dup["template_ids"] or dup["ref_count"])
    if has_dup and not force:
        raise HTTPException(
            status_code=409,
            detail=(
                f"文件「{filename}」已导入(模板 {len(dup['template_ids'])} 个、"
                f"章节 {dup['ref_count']} 条),不可重复入库;"
                f"如需覆盖请使用 force=true 重新上传。"
            ),
        )

    tmp_path: str | None = None
    try:
        # 1) Spool the upload to a temp file (the parse client takes a path).
        #    This happens inside the try so the file is always cleaned up.
        payload = await file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            tmp.write(payload)

        # 2) Remote parse to Markdown. The client is called synchronously, so
        #    run it in a worker thread rather than blocking the event loop.
        try:
            markdown = await asyncio.to_thread(parse_file_to_markdown, tmp_path)
        except Exception as e:  # noqa: BLE001
            raise HTTPException(status_code=400, detail=f"文档解析失败:{e}") from e
        if not markdown or not markdown.strip():
            raise HTTPException(status_code=400, detail="解析结果为空")

        # 3) Extract the table of contents (section headings).
        sections = extract_sections_from_text(markdown)
        if not sections:
            raise HTTPException(status_code=400, detail="未从文档中识别到章节标题(目录)")

        # 4) Split bodies by heading, desensitize, drop empty, clamp long ones.
        ref_sections, stats = _build_reference_sections(markdown)
        logger.info(
            "解析结果 | md_len=%s | toc=%s | 入库章节=%s | 跳过空正文=%s | 截断超长=%s | 脱敏数字串=%s",
            len(markdown), len(sections), len(ref_sections),
            stats["skipped_empty"], stats["truncated"], stats["masked_total"],
        )

        # 5) Match the uploaded TOC against the default template to get each
        #    section's prompt / output contract. Runs in a worker thread: it
        #    makes LLM calls and must not block the event loop.
        resolved = await asyncio.to_thread(resolve_uploaded_template_prompts, sections)
        logger.info(
            "提示词匹配完成 | 章节=%s | 命中提示词=%s",
            len(resolved), sum(1 for r in resolved if (r.get("sectionPrompt") or "").strip()),
        )

        # 6) Overwrite (force=true): remove the previously imported data now
        #    that the replacement content is ready.
        if has_dup:
            logger.info(
                "重复导入覆盖 | file=%s | 删除旧模板=%s | 旧章节=%s",
                filename, len(dup["template_ids"]), dup["ref_count"],
            )
            _delete_by_source(db, filename, dup["template_ids"])

        now = datetime.now()

        # 7) Create the template (TOC + prompts / output contracts).
        template = ReportTemplate(
            id=uuid.uuid4().hex,
            name=_clamp_title(os.path.splitext(filename)[0]) or "上传模板",
            description=_build_description(filename),
            is_default=False,
            is_active=True,
            created_at=now,
            updated_at=now,
        )
        db.add(template)
        # Flush so the template row exists before its sections/references.
        db.flush()

        for i, sec in enumerate(sections):
            # resolve_uploaded_template_prompts may return fewer entries than
            # sections; missing ones fall back to no prompt/contract.
            prompt_info = resolved[i] if i < len(resolved) else {}
            db.add(
                ReportTemplateSection(
                    id=uuid.uuid4().hex,
                    template_id=template.id,
                    section_key=normalize_section_key(sec["sectionKey"], sec["sectionTitle"]),
                    section_title=_clamp_title(sec["sectionTitle"]),
                    section_prompt=(prompt_info.get("sectionPrompt") or None),
                    section_output_contract=(prompt_info.get("sectionOutputContract") or None),
                    section_order=i,
                    examples="",
                    created_at=now,
                    updated_at=now,
                )
            )

        # 8) Store the section bodies (report_section_references rows).
        saved_refs: list[SectionReferenceItem] = []
        for sec in ref_sections:
            ref = ReportSectionReference(
                id=uuid.uuid4().hex,
                template_id=template.id,
                source_file=filename,
                section_key=sec["section_key"],
                section_title=_clamp_title(sec["section_title"]),
                section_order=parse_section_order(sec["section_key"]),
                content=sec["content"],
                created_at=now,
                updated_at=now,
            )
            db.add(ref)
            saved_refs.append(
                SectionReferenceItem(
                    id=ref.id,
                    templateId=ref.template_id,
                    sourceFile=ref.source_file,
                    sectionKey=ref.section_key,
                    sectionTitle=ref.section_title,
                    sectionOrder=ref.section_order,
                    contentLength=len(ref.content or ""),
                    content=ref.content or "",
                )
            )

        db.commit()
        logger.info(
            "模板上传完成 | file=%s | toc=%s | refs=%s",
            filename, len(sections), len(saved_refs),
        )

        return UploadTemplateResult(
            template=_serialize_template(db, template.id),
            sourceFile=filename,
            markdownLength=len(markdown),
            totalSections=len(sections),
            totalReferences=len(saved_refs),
            references=saved_refs,
            parseWarnings=[],
        )
    except HTTPException:
        db.rollback()
        raise
    except Exception as e:  # noqa: BLE001
        db.rollback()
        # Keep the traceback in the server log; the client only sees the message.
        logger.exception("模板创建失败 | file=%s", filename)
        raise HTTPException(status_code=400, detail=f"模板创建失败:{e}") from e
    finally:
        if tmp_path:
            try:
                os.remove(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the result.
                pass
|
||
|
||
|
||
# List every template, newest first, each serialized with its sections.
# (Comment rather than docstring: FastAPI would publish a docstring in OpenAPI.)
# NOTE(review): _serialize_template re-queries each template and its sections,
# so this issues two extra queries per template.
@router.get("", response_model=list[TemplateItem], summary="获取模板列表")
def list_templates_route(db: Session = Depends(get_db)):
    newest_first = db.query(ReportTemplate).order_by(ReportTemplate.created_at.desc())
    template_ids = [row.id for row in newest_first.all()]
    return [_serialize_template(db, tid) for tid in template_ids]
|
||
|
||
|
||
# Return a single template with its sections; _serialize_template raises 404
# when the id is unknown.
@router.get("/{template_id}", response_model=TemplateItem, summary="获取模板详情")
def get_template_route(template_id: str, db: Session = Depends(get_db)):
    item = _serialize_template(db, template_id)
    return item
|
||
|
||
|
||
# Delete a template together with its sections and section references.
# References are removed when they belong to this template. As a fallback for
# legacy rows whose template_id was never recorded (NULL or empty), unowned
# references sharing the template's source file are removed too. References
# owned by a different template are left alone even if they name the same file.
# Responds 404 when the template does not exist.
# (Comment rather than docstring: FastAPI would publish a docstring in OpenAPI.)
@router.delete("/{template_id}", status_code=204, summary="删除模板(含其来源文件的章节范文)")
def delete_template_route(template_id: str, db: Session = Depends(get_db)):
    t = db.query(ReportTemplate).filter(ReportTemplate.id == template_id).first()
    if not t:
        raise HTTPException(status_code=404, detail="模板不存在")
    source_file = _extract_source_file(t.description)

    conditions = [ReportSectionReference.template_id == t.id]
    if source_file:
        # Restrict the source_file fallback to unowned legacy rows.
        conditions.append(
            (ReportSectionReference.source_file == source_file)
            & or_(
                ReportSectionReference.template_id.is_(None),
                ReportSectionReference.template_id == "",
            )
        )

    try:
        # Delete child rows before the parent template, so this also works if
        # the schema enforces foreign keys to the template without cascades.
        ref_deleted = (
            db.query(ReportSectionReference)
            .filter(or_(*conditions))
            .delete(synchronize_session=False)
        )
        db.query(ReportTemplateSection).filter(
            ReportTemplateSection.template_id == t.id
        ).delete(synchronize_session=False)
        db.delete(t)
        db.commit()
    except Exception:
        db.rollback()
        raise

    logger.info(
        "模板删除完成 | template_id=%s | source_file=%s | 删除章节范文=%s",
        template_id, source_file, ref_deleted,
    )
|