""" routers/template.py 报告模板管理:上传文档 → 远程解析为 Markdown → 抽取目录并为每个目录生成声明 → 创建模板(目录 + 声明)→ 按章节拆分正文并入库远程 MySQL。 """ from __future__ import annotations import asyncio import logging import os import re import tempfile import uuid from datetime import datetime from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile from sqlalchemy import or_ from sqlalchemy.orm import Session from database import get_db from database.models import ( ReportSectionReference, ReportTemplate, ReportTemplateSection, ) from schemas.template import ( SectionReferenceItem, TemplateItem, TemplateSectionItem, UploadTemplateResult, ) from services.template_prompt_mapper import resolve_uploaded_template_prompts from services.desensitize_service import count_masked_numbers, desensitize_content from services.file_parse_client import parse_file_to_markdown from services.section_extractor import ( clamp_text_bytes, extract_sections_from_text, normalize_section_key, parse_section_order, split_markdown_into_sections, ) from config import settings logger = logging.getLogger(__name__) router = APIRouter(prefix="/templates", tags=["报告模板管理"]) ALLOWED_SUFFIXES = {".doc", ".docx", ".pdf", ".txt", ".md", ".html", ".htm", ".rtf"} def _clamp_title(value: str | None) -> str: return str(value or "").strip()[:255] def _build_description(source_file: str | None) -> str: sf = str(source_file or "").strip() base = "通过文件上传导入" return f"{base}\n来源文件:{sf}" if sf else base def _find_duplicates(db: Session, filename: str) -> dict: """按来源文件名查找已导入的模板与章节范文,用于重复检查。""" template_ids = [ t.id for t in db.query(ReportTemplate.id) .filter(ReportTemplate.description.like(f"%来源文件:{filename}%")) .all() ] ref_count = ( db.query(ReportSectionReference) .filter(ReportSectionReference.source_file == filename) .count() ) return {"template_ids": template_ids, "ref_count": ref_count} def _delete_by_source(db: Session, filename: str, template_ids: list[str]) -> None: """删除指定来源文件已导入的模板(含章节)与章节范文。""" for tid in template_ids: db.query(ReportTemplateSection).filter( ReportTemplateSection.template_id == tid ).delete(synchronize_session=False) db.query(ReportTemplate).filter(ReportTemplate.id == tid).delete( synchronize_session=False ) db.query(ReportSectionReference).filter( ReportSectionReference.source_file == filename ).delete(synchronize_session=False) db.commit() def _extract_source_file(description: str | None) -> str | None: """从模板描述中解析来源文件名(上传时写入 '来源文件:xxx')。""" m = re.search(r"来源文件\s*[::]\s*(.+)$", str(description or "")) if not m: return None return (m.group(1) or "").strip() or None def _serialize_template(db: Session, template_id: str) -> TemplateItem: t = db.query(ReportTemplate).filter(ReportTemplate.id == template_id).first() if not t: raise HTTPException(status_code=404, detail="模板不存在") sections = ( db.query(ReportTemplateSection) .filter(ReportTemplateSection.template_id == t.id) .order_by(ReportTemplateSection.section_order.asc()) .all() ) src = _extract_source_file(t.description) return TemplateItem( id=t.id, name=t.name, description=t.description, sourceFile=src, createdAt=t.created_at.strftime("%Y-%m-%d %H:%M:%S") if t.created_at else None, updatedAt=t.updated_at.strftime("%Y-%m-%d %H:%M:%S") if t.updated_at else None, isDefault=t.is_default, isActive=t.is_active, sections=[ TemplateSectionItem( id=s.id, sectionKey=s.section_key, sectionTitle=s.section_title, sectionPrompt=s.section_prompt, sectionOutputContract=s.section_output_contract, sectionOrder=s.section_order, examples=s.examples, ) for s in sections ], ) @router.post("/upload", response_model=UploadTemplateResult, summary="上传文档并解析为模板(目录+声明)与章节内容") async def upload_template_route( file: UploadFile = File(...), force: bool = Query(False, description="为 true 时覆盖同名来源文件的已有模板与章节后重新导入"), db: Session = Depends(get_db), ): filename = (file.filename or "").strip() suffix = os.path.splitext(filename)[1].lower() if suffix not in ALLOWED_SUFFIXES: raise HTTPException( status_code=400, detail=f"不支持的文件格式({suffix or '未知'});支持:{', '.join(sorted(ALLOWED_SUFFIXES))}", ) # 0) 重复检查:同名来源文件已导入则拒绝(force=true 则先删除旧数据再重导) dup = _find_duplicates(db, filename) if dup["template_ids"] or dup["ref_count"]: if not force: raise HTTPException( status_code=409, detail=( f"文件「{filename}」已导入(模板 {len(dup['template_ids'])} 个、" f"章节 {dup['ref_count']} 条),不可重复入库;" f"如需覆盖请使用 force=true 重新上传。" ), ) logger.info( "重复导入覆盖 | file=%s | 删除旧模板=%s | 旧章节=%s", filename, len(dup["template_ids"]), dup["ref_count"], ) _delete_by_source(db, filename, dup["template_ids"]) # 1) 落盘临时文件 with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: tmp_path = tmp.name tmp.write(await file.read()) try: # 2) 远程解析为 Markdown try: markdown = parse_file_to_markdown(tmp_path) except Exception as e: # noqa: BLE001 raise HTTPException(status_code=400, detail=f"文档解析失败:{e}") if not markdown or not markdown.strip(): raise HTTPException(status_code=400, detail="解析结果为空") # 3) 抽取目录 sections = extract_sections_from_text(markdown) if not sections: raise HTTPException(status_code=400, detail="未从文档中识别到章节标题(目录)") # 4) 按标题拆分正文,对每段正文脱敏(去精确数字等);跳过无正文的章节 raw_sections = split_markdown_into_sections(markdown) max_bytes = int(getattr(settings, "SECTION_CONTENT_MAX_BYTES", 60000) or 60000) ref_sections: list[dict] = [] masked_total = 0 skipped_empty = 0 truncated = 0 for s in raw_sections: filtered = desensitize_content(s["content"]) if not filtered.strip(): skipped_empty += 1 continue masked_total += max(count_masked_numbers(s["content"], filtered), 0) clamped = clamp_text_bytes(filtered, max_bytes) if clamped is not filtered and len(clamped) != len(filtered): truncated += 1 ref_sections.append({**s, "content": clamped}) logger.info( "解析结果 | md_len=%s | toc=%s | 入库章节=%s | 跳过空正文=%s | 截断超长=%s | 脱敏数字串=%s", len(markdown), len(sections), len(ref_sections), skipped_empty, truncated, masked_total, ) # 5) 复刻 eval_report:将上传目录匹配默认模板,得到每章节 提示词/输出合同/示例 # 放到工作线程执行:内部含并行 LLM 调用,避免阻塞事件循环(上传期间仍可并发处理其它请求) resolved = await asyncio.to_thread(resolve_uploaded_template_prompts, sections) logger.info( "提示词匹配完成 | 章节=%s | 命中提示词=%s", len(resolved), sum(1 for r in resolved if (r.get("sectionPrompt") or "").strip()), ) now = datetime.now() # 6) 创建模板(目录 + 提示词/输出合同/示例) template = ReportTemplate( id=uuid.uuid4().hex, name=os.path.splitext(filename)[0] or "上传模板", description=_build_description(filename), is_default=False, is_active=True, created_at=now, updated_at=now, ) db.add(template) db.flush() for i, sec in enumerate(sections): r = resolved[i] if i < len(resolved) else {} db.add( ReportTemplateSection( id=uuid.uuid4().hex, template_id=template.id, section_key=normalize_section_key(sec["sectionKey"], sec["sectionTitle"]), section_title=_clamp_title(sec["sectionTitle"]), section_prompt=(r.get("sectionPrompt") or None), section_output_contract=(r.get("sectionOutputContract") or None), section_order=i, examples="", created_at=now, updated_at=now, ) ) # 7) 章节内容入库(report_section_references 格式) saved_refs: list[SectionReferenceItem] = [] for sec in ref_sections: ref = ReportSectionReference( id=uuid.uuid4().hex, template_id=template.id, source_file=filename, section_key=sec["section_key"], section_title=_clamp_title(sec["section_title"]), section_order=parse_section_order(sec["section_key"]), content=sec["content"], created_at=now, updated_at=now, ) db.add(ref) saved_refs.append( SectionReferenceItem( id=ref.id, templateId=ref.template_id, sourceFile=ref.source_file, sectionKey=ref.section_key, sectionTitle=ref.section_title, sectionOrder=ref.section_order, contentLength=len(ref.content or ""), content=ref.content or "", ) ) db.commit() logger.info( "模板上传完成 | file=%s | toc=%s | refs=%s", filename, len(sections), len(saved_refs), ) return UploadTemplateResult( template=_serialize_template(db, template.id), sourceFile=filename, markdownLength=len(markdown), totalSections=len(sections), totalReferences=len(saved_refs), references=saved_refs, parseWarnings=[], ) except HTTPException: db.rollback() raise except Exception as e: # noqa: BLE001 db.rollback() raise HTTPException(status_code=400, detail=f"模板创建失败:{e}") finally: try: os.remove(tmp_path) except OSError: pass @router.get("", response_model=list[TemplateItem], summary="获取模板列表") def list_templates_route(db: Session = Depends(get_db)): rows = ( db.query(ReportTemplate) .order_by(ReportTemplate.created_at.desc()) .all() ) return [_serialize_template(db, r.id) for r in rows] @router.get("/{template_id}", response_model=TemplateItem, summary="获取模板详情") def get_template_route(template_id: str, db: Session = Depends(get_db)): return _serialize_template(db, template_id) @router.delete("/{template_id}", status_code=204, summary="删除模板(含其来源文件的章节范文)") def delete_template_route(template_id: str, db: Session = Depends(get_db)): t = db.query(ReportTemplate).filter(ReportTemplate.id == template_id).first() if not t: raise HTTPException(status_code=404, detail="模板不存在") source_file = _extract_source_file(t.description) db.query(ReportTemplateSection).filter( ReportTemplateSection.template_id == t.id ).delete(synchronize_session=False) db.delete(t) # 同步删除该模板在 report_section_references 中的章节内容: # 优先按 template_id 精确删除;同时按 source_file 兜底清理历史数据(template_id 为空的旧记录) conditions = [ReportSectionReference.template_id == t.id] if source_file: conditions.append(ReportSectionReference.source_file == source_file) ref_deleted = ( db.query(ReportSectionReference) .filter(or_(*conditions)) .delete(synchronize_session=False) ) db.commit() logger.info( "模板删除完成 | template_id=%s | source_file=%s | 删除章节范文=%s", template_id, source_file, ref_deleted, )