205 lines
6.4 KiB
Python
205 lines
6.4 KiB
Python
"""
|
||
routers/report.py
|
||
后评价报告「核心生成」路由(独立抽取版)。
|
||
|
||
从 eval_report 的 routers/write.py 摘取报告生成相关端点,去除鉴权依赖,
|
||
项目查询改用轻量的 services/project_service.get_project。
|
||
业务逻辑在 services/report_generation_service.py。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||
from fastapi.responses import StreamingResponse
|
||
from sqlalchemy.orm import Session
|
||
|
||
from database import SessionLocal, get_db
|
||
from database.models import ReportTemplate, ReportTemplateSection
|
||
from schemas.write import (
|
||
GenerateReportJobCreate,
|
||
GenerateReportJobItem,
|
||
GenerateReportResult,
|
||
)
|
||
from services.project_service import get_project
|
||
from services.report_generation_service import (
|
||
create_report_job,
|
||
get_report_job,
|
||
get_report_result,
|
||
get_report_stream_snapshot,
|
||
retry_report_chapter,
|
||
cancel_report_job,
|
||
)
|
||
|
||
# Router for the report-generation endpoints; every route below is mounted
# under the "/write" prefix and grouped under a single OpenAPI tag.
router = APIRouter(
    prefix="/write",
    tags=["后评价报告生成"],
)
|
||
|
||
|
||
@router.get(
    "/projects/{project_id}/generate-sections",
    summary="按章节智能体生成提示词清单",
)
def generate_sections_prompt(
    project_id: str,
    template_id: Optional[str] = None,
    db: Session = Depends(get_db),
):
    # Return the per-section prompt list for a project's report template.
    #
    # Template resolution: the active template whose id equals ``template_id``
    # when one is given; if none is given or that lookup finds nothing, the
    # active default template. Responds 404 when neither exists.
    # Each section's "prompt" is a fixed instruction line followed by the
    # section's own ``section_prompt`` (treated as "" when unset), and
    # "examples" falls back to "" as well.

    # Called for its validation side effect (presumably raises for an unknown
    # project); the returned project object is not used.
    get_project(project_id, db)

    active_templates = db.query(ReportTemplate).filter(ReportTemplate.is_active == True)  # noqa: E712

    chosen = (
        active_templates.filter(ReportTemplate.id == template_id).first()
        if template_id
        else None
    )
    if not chosen:
        chosen = active_templates.filter(ReportTemplate.is_default == True).first()  # noqa: E712
    if not chosen:
        raise HTTPException(status_code=404, detail="未找到可用模板")

    prompt_header = "请基于2020后评价细则与本项目检索材料,先查要素表,再查文档段落,最后生成本章节内容。\n"

    section_rows = (
        db.query(ReportTemplateSection)
        .filter(ReportTemplateSection.template_id == chosen.id)
        .order_by(ReportTemplateSection.section_order.asc())
        .all()
    )

    def _to_entry(row):
        # Shape one template section into the response item.
        return {
            "sectionKey": row.section_key,
            "sectionTitle": row.section_title,
            "prompt": prompt_header + (row.section_prompt or ""),
            "examples": row.examples or "",
        }

    return {
        "templateId": chosen.id,
        "templateName": chosen.name,
        "sections": [_to_entry(row) for row in section_rows],
    }
|
||
|
||
|
||
@router.post(
    "/projects/{project_id}/generate-report-job",
    summary="创建分章异步报告生成任务",
    response_model=GenerateReportJobItem,
)
def create_generate_report_job(
    project_id: str,
    body: GenerateReportJobCreate,
    db: Session = Depends(get_db),
    x_user_id: Optional[str] = Header(default=None, alias="X-User-Id"),
):
    # Create an asynchronous, chapter-by-chapter report generation job.
    #
    # The requester is taken verbatim from the optional ``X-User-Id`` header
    # (None when absent); no authentication is performed in this router.

    # Validation call only: the project lookup result is discarded.
    get_project(project_id, db)

    job_options = {
        "template_id": body.templateId,
        "top_k": body.topK,
        "requested_by": x_user_id,
    }
    return create_report_job(project_id, db, **job_options)
|
||
|
||
|
||
@router.get(
    "/projects/{project_id}/generate-report-job/{job_id}",
    summary="查询分章异步报告任务进度",
    response_model=GenerateReportJobItem,
)
def get_generate_report_job(
    project_id: str,
    job_id: str,
    db: Session = Depends(get_db),
):
    # Report the current progress/state of one report-generation job;
    # all lookup and error handling lives in get_report_job.
    job_state = get_report_job(project_id, job_id, db)
    return job_state
|
||
|
||
|
||
@router.get(
    "/projects/{project_id}/generate-report-job/{job_id}/result",
    summary="获取分章异步报告任务结果",
    response_model=GenerateReportResult,
)
def get_generate_report_result(
    project_id: str,
    job_id: str,
    include_debug: bool = False,
    db: Session = Depends(get_db),
):
    # Fetch the generated report for a job. ``include_debug`` is a query flag
    # forwarded unchanged to get_report_result.
    job_result = get_report_result(project_id, job_id, db, include_debug=include_debug)
    return job_result
|
||
|
||
|
||
@router.get(
    "/projects/{project_id}/generate-report-job/{job_id}/events",
    summary="订阅分章异步报告任务实时事件(SSE)",
)
async def stream_generate_report_job_events(
    project_id: str,
    job_id: str,
    include_debug: bool = False,
):
    # Stream a report job's progress as Server-Sent Events.
    #
    # Emitted events:
    #   - "snapshot":  whenever the serialized job/result payload changes;
    #   - "keepalive": after 20 consecutive unchanged polls (about 5 s at the
    #                  0.25 s poll interval) so idle connections stay open;
    #   - "end":       once the job status is completed/failed/cancelled,
    #                  carrying the final payload; the stream then closes.
    #
    # This is an ``async def`` endpoint, so it runs on the event loop. The
    # synchronous SQLAlchemy work is therefore executed in the default
    # thread-pool executor; running it inline would block every other request
    # served by this worker during each query, and the fallback query below
    # repeats on every poll for every subscriber. Each helper opens and closes
    # its own short-lived session inside the worker thread, so no connection
    # is held for the lifetime of the stream.

    def _check_job_exists() -> None:
        # Raises whatever get_report_job raises for an unknown job.
        with SessionLocal() as db:
            get_report_job(project_id, job_id, db)

    def _load_snapshot_from_db() -> dict:
        with SessionLocal() as db:
            job = get_report_job(project_id, job_id, db)
            result = get_report_result(project_id, job_id, db, include_debug=include_debug)
            # mode="json" turns non-JSON-native values (datetimes, Decimals,
            # enums) into JSON-compatible ones, so json.dumps below cannot fail
            # on them and an enum "status" compares as its plain value.
            return {
                "job": job.model_dump(mode="json"),
                "result": result.model_dump(mode="json"),
            }

    # Validate before streaming so an unknown job becomes a normal HTTP error
    # response; exceptions raised in the executor propagate through await.
    await asyncio.get_running_loop().run_in_executor(None, _check_job_exists)

    async def _event_stream():
        last_payload = ""
        idle_ticks = 0
        while True:
            # NOTE(review): assumed to be a cheap in-memory lookup, so it stays
            # on the event loop; offload it too if it touches the DB.
            snapshot = get_report_stream_snapshot(job_id, include_debug=include_debug)
            if not snapshot:
                # No live snapshot available: rebuild one from the database.
                snapshot = await asyncio.get_running_loop().run_in_executor(
                    None, _load_snapshot_from_db
                )

            payload = json.dumps(snapshot, ensure_ascii=False, separators=(",", ":"))
            if payload != last_payload:
                last_payload = payload
                idle_ticks = 0
                yield f"event: snapshot\ndata: {payload}\n\n"
            else:
                idle_ticks += 1
                if idle_ticks >= 20:
                    idle_ticks = 0
                    yield "event: keepalive\ndata: ping\n\n"

            status = str(((snapshot.get("job") or {}).get("status") or "")).strip().lower()
            if status in ("completed", "failed", "cancelled"):
                yield f"event: end\ndata: {payload}\n\n"
                break
            await asyncio.sleep(0.25)

    return StreamingResponse(
        _event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disables response buffering in nginx so events arrive promptly.
            "X-Accel-Buffering": "no",
        },
    )
|
||
|
||
|
||
@router.post(
    "/projects/{project_id}/generate-report-job/{job_id}/retry-chapter",
    summary="重试指定章节",
    response_model=GenerateReportJobItem,
)
def retry_generate_report_chapter(
    project_id: str,
    job_id: str,
    section_key: str,
    db: Session = Depends(get_db),
):
    # Re-run a single chapter of an existing job. ``section_key`` arrives as a
    # query parameter and identifies the chapter; the updated job is returned.
    updated_job = retry_report_chapter(project_id, job_id, section_key, db)
    return updated_job
|
||
|
||
|
||
@router.post(
    "/projects/{project_id}/generate-report-job/{job_id}/cancel",
    summary="取消报告生成任务",
    response_model=GenerateReportJobItem,
)
def cancel_generate_report_job(
    project_id: str,
    job_id: str,
    db: Session = Depends(get_db),
):
    # Request cancellation of a report-generation job and return its
    # resulting state as produced by cancel_report_job.
    cancelled_job = cancel_report_job(project_id, job_id, db)
    return cancelled_job
|