""" routers/report.py 后评价报告「核心生成」路由(独立抽取版)。 从 eval_report 的 routers/write.py 摘取报告生成相关端点,去除鉴权依赖, 项目查询改用轻量的 services/project_service.get_project。 业务逻辑在 services/report_generation_service.py。 """ from __future__ import annotations import asyncio import json from typing import Optional from fastapi import APIRouter, Depends, Header, HTTPException from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from database import SessionLocal, get_db from database.models import ReportTemplate, ReportTemplateSection from schemas.write import ( GenerateReportJobCreate, GenerateReportJobItem, GenerateReportResult, ) from services.project_service import get_project from services.report_generation_service import ( create_report_job, get_report_job, get_report_result, get_report_stream_snapshot, retry_report_chapter, cancel_report_job, ) router = APIRouter(prefix="/write", tags=["后评价报告生成"]) @router.get("/projects/{project_id}/generate-sections", summary="按章节智能体生成提示词清单") def generate_sections_prompt( project_id: str, template_id: Optional[str] = None, db: Session = Depends(get_db), ): _ = get_project(project_id, db) template = None if template_id: template = db.query(ReportTemplate).filter(ReportTemplate.id == template_id, ReportTemplate.is_active == True).first() # noqa: E712 if not template: template = db.query(ReportTemplate).filter(ReportTemplate.is_default == True, ReportTemplate.is_active == True).first() # noqa: E712 if not template: raise HTTPException(status_code=404, detail="未找到可用模板") sections = ( db.query(ReportTemplateSection) .filter(ReportTemplateSection.template_id == template.id) .order_by(ReportTemplateSection.section_order.asc()) .all() ) return { "templateId": template.id, "templateName": template.name, "sections": [ { "sectionKey": s.section_key, "sectionTitle": s.section_title, "prompt": ( "请基于2020后评价细则与本项目检索材料,先查要素表,再查文档段落,最后生成本章节内容。\n" + (s.section_prompt or "") ), "examples": s.examples or "", } for s in sections ], } @router.post( "/projects/{project_id}/generate-report-job", response_model=GenerateReportJobItem, summary="创建分章异步报告生成任务", ) def create_generate_report_job( project_id: str, body: GenerateReportJobCreate, db: Session = Depends(get_db), x_user_id: Optional[str] = Header(default=None, alias="X-User-Id"), ): _ = get_project(project_id, db) return create_report_job( project_id, db, template_id=body.templateId, top_k=body.topK, requested_by=x_user_id, ) @router.get( "/projects/{project_id}/generate-report-job/{job_id}", response_model=GenerateReportJobItem, summary="查询分章异步报告任务进度", ) def get_generate_report_job( project_id: str, job_id: str, db: Session = Depends(get_db), ): return get_report_job(project_id, job_id, db) @router.get( "/projects/{project_id}/generate-report-job/{job_id}/result", response_model=GenerateReportResult, summary="获取分章异步报告任务结果", ) def get_generate_report_result( project_id: str, job_id: str, include_debug: bool = False, db: Session = Depends(get_db), ): return get_report_result(project_id, job_id, db, include_debug=include_debug) @router.get( "/projects/{project_id}/generate-report-job/{job_id}/events", summary="订阅分章异步报告任务实时事件(SSE)", ) async def stream_generate_report_job_events( project_id: str, job_id: str, include_debug: bool = False, ): # 校验后立即释放连接;SSE 循环中按需短连接查询,避免长连占满连接池 with SessionLocal() as db: _ = get_report_job(project_id, job_id, db) async def _event_stream(): last_payload = "" idle_ticks = 0 while True: snapshot = get_report_stream_snapshot(job_id, include_debug=include_debug) if not snapshot: with SessionLocal() as db: job = get_report_job(project_id, job_id, db) result = get_report_result(project_id, job_id, db, include_debug=include_debug) snapshot = { "job": job.model_dump(), "result": result.model_dump(), } payload = json.dumps(snapshot, ensure_ascii=False, separators=(",", ":")) if payload != last_payload: last_payload = payload idle_ticks = 0 yield f"event: snapshot\ndata: {payload}\n\n" else: idle_ticks += 1 if idle_ticks >= 20: idle_ticks = 0 yield "event: keepalive\ndata: ping\n\n" status = str(((snapshot.get("job") or {}).get("status") or "")).strip().lower() if status in ("completed", "failed", "cancelled"): yield f"event: end\ndata: {payload}\n\n" break await asyncio.sleep(0.25) return StreamingResponse( _event_stream(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) @router.post( "/projects/{project_id}/generate-report-job/{job_id}/retry-chapter", response_model=GenerateReportJobItem, summary="重试指定章节", ) def retry_generate_report_chapter( project_id: str, job_id: str, section_key: str, db: Session = Depends(get_db), ): return retry_report_chapter(project_id, job_id, section_key, db) @router.post( "/projects/{project_id}/generate-report-job/{job_id}/cancel", response_model=GenerateReportJobItem, summary="取消报告生成任务", ) def cancel_generate_report_job( project_id: str, job_id: str, db: Session = Depends(get_db), ): return cancel_report_job(project_id, job_id, db)