xxy aa98ea2623 @
Initial commit

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@
2026-06-05 18:45:29 +08:00

205 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
routers/report.py
后评价报告「核心生成」路由(独立抽取版)。
从 eval_report 的 routers/write.py 摘取报告生成相关端点,去除鉴权依赖,
项目查询改用轻量的 services/project_service.get_project。
业务逻辑在 services/report_generation_service.py。
"""
from __future__ import annotations
import asyncio
import json
from typing import Optional
from fastapi import APIRouter, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from database import SessionLocal, get_db
from database.models import ReportTemplate, ReportTemplateSection
from schemas.write import (
GenerateReportJobCreate,
GenerateReportJobItem,
GenerateReportResult,
)
from services.project_service import get_project
from services.report_generation_service import (
create_report_job,
get_report_job,
get_report_result,
get_report_stream_snapshot,
retry_report_chapter,
cancel_report_job,
)
# Router for the report-generation endpoints; every route below is mounted
# under the "/write" prefix and grouped in the OpenAPI docs under the
# "post-evaluation report generation" tag.
router = APIRouter(prefix="/write", tags=["后评价报告生成"])
@router.get("/projects/{project_id}/generate-sections", summary="按章节智能体生成提示词清单")
def generate_sections_prompt(
    project_id: str,
    template_id: Optional[str] = None,
    db: Session = Depends(get_db),
):
    """Return the per-section generation prompts of a report template.

    Template selection: the active template whose id equals ``template_id``
    when one is given and found; otherwise the active template flagged as
    default. Sections are listed in ascending ``section_order``, and each
    prompt is a fixed instruction line followed by the section's own
    ``section_prompt`` (empty when unset).

    Raises:
        HTTPException: 404 when neither lookup yields a template. ``get_project``
            may raise as well for the project id (behaviour defined there).
    """
    # Only called for its checks; the returned project is not used.
    get_project(project_id, db)

    # Query objects are generative, so this base query can be refined twice.
    active_templates = db.query(ReportTemplate).filter(ReportTemplate.is_active == True)  # noqa: E712
    chosen = None
    if template_id:
        chosen = active_templates.filter(ReportTemplate.id == template_id).first()
    if not chosen:
        # A missing/inactive template_id silently falls back to the default.
        chosen = active_templates.filter(ReportTemplate.is_default == True).first()  # noqa: E712
    if not chosen:
        raise HTTPException(status_code=404, detail="未找到可用模板")

    ordered_sections = (
        db.query(ReportTemplateSection)
        .filter(ReportTemplateSection.template_id == chosen.id)
        .order_by(ReportTemplateSection.section_order.asc())
        .all()
    )
    instruction = "请基于2020后评价细则与本项目检索材料先查要素表再查文档段落最后生成本章节内容。\n"

    def _to_item(section):
        # Serialize one section using the camelCase keys the client expects.
        return {
            "sectionKey": section.section_key,
            "sectionTitle": section.section_title,
            "prompt": instruction + (section.section_prompt or ""),
            "examples": section.examples or "",
        }

    return {
        "templateId": chosen.id,
        "templateName": chosen.name,
        "sections": [_to_item(section) for section in ordered_sections],
    }
@router.post(
    "/projects/{project_id}/generate-report-job",
    response_model=GenerateReportJobItem,
    summary="创建分章异步报告生成任务",
)
def create_generate_report_job(
    project_id: str,
    body: GenerateReportJobCreate,
    db: Session = Depends(get_db),
    x_user_id: Optional[str] = Header(default=None, alias="X-User-Id"),
):
    """Create an asynchronous, chapter-by-chapter report generation job.

    ``body.templateId`` and ``body.topK`` are forwarded to ``create_report_job``.
    The optional ``X-User-Id`` header is recorded as ``requested_by``. No
    authentication dependency is attached to this route.
    """
    # Project lookup first; presumably raises for an unknown project id.
    get_project(project_id, db)
    job_options = {
        "template_id": body.templateId,
        "top_k": body.topK,
        "requested_by": x_user_id,
    }
    return create_report_job(project_id, db, **job_options)
@router.get(
    "/projects/{project_id}/generate-report-job/{job_id}",
    response_model=GenerateReportJobItem,
    summary="查询分章异步报告任务进度",
)
def get_generate_report_job(
    project_id: str,
    job_id: str,
    db: Session = Depends(get_db),
):
    """Return the current progress record of one report generation job.

    Lookup (and any not-found handling) is delegated to ``get_report_job``.
    """
    job_item = get_report_job(project_id, job_id, db)
    return job_item
@router.get(
    "/projects/{project_id}/generate-report-job/{job_id}/result",
    response_model=GenerateReportResult,
    summary="获取分章异步报告任务结果",
)
def get_generate_report_result(
    project_id: str,
    job_id: str,
    include_debug: bool = False,
    db: Session = Depends(get_db),
):
    """Return the generated report of a job.

    ``include_debug`` (query parameter, default ``False``) is passed through
    to ``get_report_result``, which decides what extra data it adds.
    """
    report_result = get_report_result(project_id, job_id, db, include_debug=include_debug)
    return report_result
@router.get(
    "/projects/{project_id}/generate-report-job/{job_id}/events",
    summary="订阅分章异步报告任务实时事件SSE",
)
async def stream_generate_report_job_events(
    project_id: str,
    job_id: str,
    include_debug: bool = False,
):
    """Stream a report job's progress as Server-Sent Events (SSE).

    The job is validated once before the stream opens, so an unknown job is
    reported through whatever ``get_report_job`` raises instead of a stream
    that breaks after headers were sent. The stream then polls every 0.25 s:

    * the snapshot comes from ``get_report_stream_snapshot``; when that is
      empty it is rebuilt from the database (job record + result);
    * ``event: snapshot`` is sent whenever the serialized snapshot changes;
    * ``event: keepalive`` is sent after 20 unchanged polls (about 5 s);
    * ``event: end`` with the last payload is sent, and the stream closes,
      once the job status is ``completed``, ``failed`` or ``cancelled``.

    The database work uses synchronous SQLAlchemy sessions. It is run in the
    default thread-pool executor because executing it inline inside this
    ``async`` endpoint would block the event loop, stalling every other
    request on the worker for the duration of each query.
    """

    def _validate_job() -> None:
        # Short-lived session, released right after the check so the
        # long-running stream does not pin a pooled connection.
        with SessionLocal() as db:
            get_report_job(project_id, job_id, db)

    def _snapshot_from_db() -> dict:
        # Fresh short-lived session per fallback poll, for the same reason.
        with SessionLocal() as db:
            job = get_report_job(project_id, job_id, db)
            result = get_report_result(project_id, job_id, db, include_debug=include_debug)
            # mode="json" converts non-JSON-native field values (e.g. datetime)
            # to JSON-safe ones so json.dumps below cannot reject them.
            return {
                "job": job.model_dump(mode="json"),
                "result": result.model_dump(mode="json"),
            }

    # Exceptions raised in the executor thread propagate here, before the
    # StreamingResponse is created.
    await asyncio.get_running_loop().run_in_executor(None, _validate_job)

    async def _event_stream():
        loop = asyncio.get_running_loop()
        last_payload = ""
        idle_ticks = 0
        while True:
            # NOTE(review): assumed to be a cheap in-process lookup, so it stays
            # on the event loop; offload it too if it turns out to do I/O.
            snapshot = get_report_stream_snapshot(job_id, include_debug=include_debug)
            if not snapshot:
                snapshot = await loop.run_in_executor(None, _snapshot_from_db)
            payload = json.dumps(snapshot, ensure_ascii=False, separators=(",", ":"))
            if payload != last_payload:
                last_payload = payload
                idle_ticks = 0
                yield f"event: snapshot\ndata: {payload}\n\n"
            else:
                idle_ticks += 1
                if idle_ticks >= 20:
                    # Periodic traffic on an otherwise idle stream, presumably so
                    # clients/proxies do not drop the connection.
                    idle_ticks = 0
                    yield "event: keepalive\ndata: ping\n\n"
            status = str(((snapshot.get("job") or {}).get("status") or "")).strip().lower()
            if status in ("completed", "failed", "cancelled"):
                yield f"event: end\ndata: {payload}\n\n"
                break
            await asyncio.sleep(0.25)

    return StreamingResponse(
        _event_stream(),
        media_type="text/event-stream",
        headers={
            # Disable caching and proxy buffering so each event reaches the
            # client as soon as it is yielded.
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
@router.post(
    "/projects/{project_id}/generate-report-job/{job_id}/retry-chapter",
    response_model=GenerateReportJobItem,
    summary="重试指定章节",
)
def retry_generate_report_chapter(
    project_id: str,
    job_id: str,
    section_key: str,
    db: Session = Depends(get_db),
):
    """Re-run generation of a single chapter of an existing job.

    ``section_key`` is not part of the path, so FastAPI reads it from the
    query string. Returns the job record produced by ``retry_report_chapter``.
    """
    refreshed_job = retry_report_chapter(project_id, job_id, section_key, db)
    return refreshed_job
@router.post(
    "/projects/{project_id}/generate-report-job/{job_id}/cancel",
    response_model=GenerateReportJobItem,
    summary="取消报告生成任务",
)
def cancel_generate_report_job(
    project_id: str,
    job_id: str,
    db: Session = Depends(get_db),
):
    """Request cancellation of a report generation job.

    Returns the job record produced by ``cancel_report_job``.
    """
    cancelled_job = cancel_report_job(project_id, job_id, db)
    return cancelled_job