report_generation/services/llm_client.py
xxy aa98ea2623 @
Initial commit

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@
2026-06-05 18:45:29 +08:00

725 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import logging
import random
import re
import time
import threading
from typing import Any, Optional
import requests
from requests import RequestException
from requests.exceptions import ChunkedEncodingError
from config import settings
_logger = logging.getLogger(__name__)
# 生成全过程追踪:完整记录输入 prompt / 调用模型 / 模型输出,写入 logs/generation_trace.log
_trace_logger = logging.getLogger("generation.trace")
_LLM_MAX_CONCURRENCY = 5
_llm_slots = threading.BoundedSemaphore(_LLM_MAX_CONCURRENCY)
class _RetryableLLMError(RuntimeError):
"""用于标记可安全重试的 LLM 调用异常。"""
class _ContentFieldStreamExtractor:
"""从流式 JSON 文本中增量提取 content 字段的已解码正文。"""
def __init__(self) -> None:
self._raw = ""
self._content_started = False
self._content_done = False
self._value_start = -1
self._consumed_pos = 0
def feed(self, chunk: str) -> tuple[str, bool]:
if not chunk:
return "", False
self._raw += chunk
emitted = ""
done_now = False
if not self._content_started:
marker = '"content"'
idx = self._raw.find(marker)
if idx == -1:
return "", False
colon = self._raw.find(":", idx + len(marker))
if colon == -1:
return "", False
quote = self._raw.find('"', colon + 1)
if quote == -1:
return "", False
self._content_started = True
self._value_start = quote + 1
self._consumed_pos = self._value_start
if self._content_started and not self._content_done:
emitted, consumed_pos, done_now = self._decode_partial_json_string(
self._raw,
self._consumed_pos,
)
self._consumed_pos = consumed_pos
if done_now:
self._content_done = True
return emitted, done_now
@staticmethod
def _decode_partial_json_string(src: str, start: int) -> tuple[str, int, bool]:
out: list[str] = []
i = start
n = len(src)
while i < n:
ch = src[i]
if ch == '"':
if i == 0 or src[i - 1] != "\\" or _ContentFieldStreamExtractor._is_escaped(src, i):
return "".join(out), i, True
if ch != "\\":
out.append(ch)
i += 1
continue
if i + 1 >= n:
break
esc = src[i + 1]
mapping = {
'"': '"',
"\\": "\\",
"/": "/",
"b": "\b",
"f": "\f",
"n": "\n",
"r": "\r",
"t": "\t",
}
if esc == "u":
if i + 6 > n:
break
hex_part = src[i + 2 : i + 6]
try:
out.append(chr(int(hex_part, 16)))
except Exception:
pass
i += 6
continue
if esc in mapping:
out.append(mapping[esc])
i += 2
continue
out.append(esc)
i += 2
return "".join(out), i, False
@staticmethod
def _is_escaped(src: str, quote_index: int) -> bool:
backslashes = 0
i = quote_index - 1
while i >= 0 and src[i] == "\\":
backslashes += 1
i -= 1
return backslashes % 2 == 0
def _format_exc_raw(e: Exception) -> str:
"""统一输出最直接的异常原文(类型 + repr"""
return f"{type(e).__name__}: {e!r}"
def _forward_content_delta(
    on_content_delta: Optional[callable],
    extractor: "_ContentFieldStreamExtractor",
    text: str,
) -> None:
    """Feed one raw content chunk to *extractor* and report decoded progress to the callback.

    The callback gets ``("delta", decoded_text)`` for newly decoded text of the JSON
    ``content`` field and ``("finalizing", "")`` once that field's closing quote is
    seen. Callback errors are logged and ignored: progress reporting must never abort
    a generation.
    """
    if not on_content_delta:
        return
    delta_text, done_now = extractor.feed(text)
    if delta_text:
        try:
            on_content_delta("delta", delta_text)
        except Exception:
            _logger.debug("on_content_delta callback failed (ignored)", exc_info=True)
    if done_now:
        try:
            on_content_delta("finalizing", "")
        except Exception:
            _logger.debug("on_content_delta callback failed (ignored)", exc_info=True)


def _chat_completions_stream_text(
    *,
    api_base: str,
    api_key: str,
    model_name: str,
    system_prompt: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
    extra_payload: dict[str, Any],
    connect_timeout_sec: int,
    read_timeout_sec: int = 300,
    on_content_delta: Optional[callable] = None,
) -> str:
    """Call ``{api_base}/chat/completions`` in SSE streaming JSON mode and return the full text.

    - The connect timeout keeps the connection phase from hanging; the read timeout
      (at least 60s) keeps the stream from hanging forever.
    - Text is concatenated from ``choices[0].delta.content`` and, for implementations
      that put the final result there, ``choices[0].message.content``.
    - *on_content_delta*, if given, receives ``(event, text)`` progress callbacks.

    Raises:
        _RetryableLLMError: HTTP 408/429/5xx, empty output, or a broken stream whose
            partial text is not a parseable JSON object.
        RuntimeError: any other non-200 HTTP status.
    """
    _logger.info(
        "LLM 流式调用开始 | model=%s | temperature=%s | max_tokens=%s | timeout_connect=%s timeout_read=%s",
        model_name, temperature, max_tokens, connect_timeout_sec, read_timeout_sec,
    )
    chunks: list[str] = []
    extractor = _ContentFieldStreamExtractor()
    # The context manager closes the streamed response so the pooled connection is
    # released even when we stop reading early (at [DONE]) or raise.
    with requests.post(
        f"{api_base}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "response_format": {"type": "json_object"},
            "stream": True,
            **extra_payload,
        },
        stream=True,
        timeout=(connect_timeout_sec, max(60, read_timeout_sec)),
    ) as resp:
        if resp.status_code in (408, 429, 500, 502, 503, 504):
            raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}")
        if resp.status_code != 200:
            raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")
        # SSE responses often omit a charset; force UTF-8 rather than requests' latin-1 default.
        resp.encoding = "utf-8"
        try:
            for line in resp.iter_lines(decode_unicode=True):
                if not line:
                    continue
                s = line.strip()
                if not s.startswith("data:"):
                    continue
                payload = s[5:].strip()
                if payload == "[DONE]":
                    break
                if not payload:
                    continue  # empty data line is not end-of-stream
                try:
                    obj = json.loads(payload)
                except ValueError:
                    continue
                if not isinstance(obj, dict):
                    continue
                choices = obj.get("choices")
                if not isinstance(choices, list) or not choices:
                    continue
                first = choices[0] if isinstance(choices[0], dict) else {}
                # Incremental text arrives in delta.content; some implementations put
                # the final result in message.content instead.
                for key in ("delta", "message"):
                    part = first.get(key)
                    if not isinstance(part, dict):
                        continue
                    content = part.get("content")
                    if isinstance(content, str) and content:
                        chunks.append(content)
                        _forward_content_delta(on_content_delta, extractor, content)
        except ChunkedEncodingError as e:
            partial_text = "".join(chunks).strip()
            # If the stream broke but the JSON already arrived complete, use it instead
            # of failing into a pointless retry.
            if partial_text:
                try:
                    parse_json_object_from_text(partial_text)
                    return partial_text
                except Exception:
                    pass
            raise _RetryableLLMError(f"LLM 流中断: {_format_exc_raw(e)}") from e
    text = "".join(chunks).strip()
    if not text:
        raise _RetryableLLMError("LLM 返回空内容")
    return text
def _post_chat_completions_non_stream(
    *,
    api_base: str,
    api_key: str,
    model_name: str,
    system_prompt: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
    extra_payload: dict[str, Any],
    connect_timeout_sec: int,
    read_timeout_sec: int,
) -> str:
    """Non-streaming JSON-mode chat/completions call; return the message content text.

    Raises:
        _RetryableLLMError: HTTP 408/429/5xx, or the reply has no non-blank content.
        RuntimeError: any other non-200 HTTP status.
    """
    # Separate connect and read timeouts: a long generation should only consume the
    # read budget instead of sharing one limit with the connection handshake.
    read_timeout = max(int(connect_timeout_sec) + 5, int(read_timeout_sec))
    resp = requests.post(
        f"{api_base}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "response_format": {"type": "json_object"},
            **extra_payload,
        },
        timeout=(connect_timeout_sec, read_timeout),
    )
    if resp.status_code in (408, 429, 500, 502, 503, 504):
        raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}")
    if resp.status_code != 200:
        raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")
    data = resp.json()
    choices = data.get("choices") if isinstance(data, dict) else None
    first = (
        choices[0]
        if isinstance(choices, list) and choices and isinstance(choices[0], dict)
        else {}
    )
    message = first.get("message") if isinstance(first.get("message"), dict) else {}
    content = message.get("content")
    if not isinstance(content, str) or not content.strip():
        raise _RetryableLLMError("LLM 返回空内容")
    return content


def _sleep_before_retry(attempt: int, backoff: float, backoff_max: float) -> None:
    """Sleep with capped exponential backoff plus a little random jitter before the next attempt."""
    sleep_sec = min(backoff * (2 ** attempt), backoff_max)
    sleep_sec += random.uniform(0, min(0.5, sleep_sec * 0.2))
    time.sleep(sleep_sec)


def chat_completions_json(
    *,
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.2,
    max_tokens: int = 4096,
    timeout_sec: int = 180,
    on_content_delta: Optional[callable] = None,
    log_context: str = "",
) -> dict[str, Any]:
    """Call the OpenAI-compatible chat/completions endpoint in JSON mode and return the parsed object.

    Uses the project settings LLM_API_BASE / LLM_API_KEY / LLM_MODEL_NAME and holds one
    slot of the module concurrency semaphore for the whole call, including retries.

    Args:
        system_prompt / user_prompt: the two chat messages sent to the model.
        temperature / max_tokens: forwarded to the API.
        timeout_sec: read timeout per attempt; <= 0 falls back to settings.LLM_HTTP_TIMEOUT_SEC.
        on_content_delta: optional ``(event, text)`` progress callback for streaming.
        log_context: caller label (e.g. a chapter number) used to tell calls apart in
            generation_trace.log.

    Returns:
        The JSON object parsed from the model output.

    Raises:
        RuntimeError: LLM not configured, a non-retryable HTTP/request error, or all
            attempts failed (timeouts, connection errors, retryable errors).
    """
    api_base = (settings.LLM_API_BASE or "").rstrip("/")
    api_key = settings.LLM_API_KEY or ""
    model_name = settings.LLM_MODEL_NAME or ""
    if not api_base or not api_key or not model_name:
        raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME")
    ctx = log_context or "-"
    _trace_logger.info(
        "[输入] context=%s | model=%s | temperature=%s | max_tokens=%s\n"
        "----- SYSTEM PROMPT -----\n%s\n"
        "----- USER PROMPT -----\n%s\n"
        "----- END INPUT -----",
        ctx, model_name, temperature, max_tokens, system_prompt, user_prompt,
    )
    extra_payload: dict[str, Any] = {}
    # Some SiliconFlow Qwen models write their answer to reasoning_content by default,
    # leaving content empty; explicitly disable thinking so the output lands in content.
    if "siliconflow" in api_base.lower() and "qwen" in model_name.lower():
        extra_payload["enable_thinking"] = False
    final_timeout_sec = int(timeout_sec or 0)
    if final_timeout_sec <= 0:
        final_timeout_sec = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90)
    # Total number of attempts (not extra retries); always at least one.
    retry_count = int(getattr(settings, "LLM_RETRY_COUNT", 2) or 2)
    if retry_count < 1:
        retry_count = 1
    retry_backoff = float(getattr(settings, "LLM_RETRY_BACKOFF_SEC", 1.0) or 1.0)
    retry_backoff_max = float(getattr(settings, "LLM_RETRY_BACKOFF_MAX_SEC", 12.0) or 12.0)
    connect_timeout_sec = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20)
    if connect_timeout_sec <= 0:
        connect_timeout_sec = 20
    # Streaming is always used; the non-streaming path is kept as an alternative implementation.
    use_stream = True
    endpoint = f"{api_base}/chat/completions"
    failure_ctx = (
        f"已重试{retry_count}次, timeout={final_timeout_sec}s, "
        f"endpoint={endpoint}, model={model_name}"
    )
    _logger.info(
        "chat_completions_json 调用 | model=%s | temperature=%s | max_tokens=%s | timeout=%s | retry=%s",
        model_name, temperature, max_tokens, final_timeout_sec, retry_count,
    )
    request_kwargs = dict(
        api_base=api_base,
        api_key=api_key,
        model_name=model_name,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        extra_payload=extra_payload,
        connect_timeout_sec=connect_timeout_sec,
        read_timeout_sec=final_timeout_sec,
    )
    with _llm_slots:
        last_err: Optional[Exception] = None
        for attempt in range(retry_count):
            try:
                if use_stream:
                    content = _chat_completions_stream_text(
                        **request_kwargs,
                        on_content_delta=on_content_delta,
                    )
                else:
                    content = _post_chat_completions_non_stream(**request_kwargs)
                try:
                    obj = parse_json_object_from_text(content)
                except ValueError as e:
                    # Record unparseable output too, so the trace shows what the model produced.
                    _trace_logger.info(
                        "[输出-解析失败] context=%s | model=%s | attempt=%d/%d | output_len=%d\n"
                        "----- MODEL OUTPUT -----\n%s\n"
                        "----- END OUTPUT -----",
                        ctx, model_name, attempt + 1, retry_count, len(content), content,
                    )
                    raise _RetryableLLMError(f"LLM JSON 解析失败: {e}") from e
                _logger.info(
                    "chat_completions_json 成功 | model=%s | attempt=%d/%d | content_len=%d | keys=%s",
                    model_name, attempt + 1, retry_count, len(content), list(obj.keys())[:8],
                )
                _trace_logger.info(
                    "[输出] context=%s | model=%s | attempt=%d/%d | output_len=%d\n"
                    "----- MODEL OUTPUT -----\n%s\n"
                    "----- END OUTPUT -----",
                    ctx, model_name, attempt + 1, retry_count, len(content), content,
                )
                return obj
            except (
                requests.ReadTimeout,
                requests.ConnectTimeout,
                requests.ConnectionError,
                ChunkedEncodingError,
            ) as e:
                last_err = e
                if attempt >= retry_count - 1:
                    raise RuntimeError(
                        f"LLM 请求超时/连接失败({failure_ctx}, raw={_format_exc_raw(e)})"
                    ) from e
                _sleep_before_retry(attempt, retry_backoff, retry_backoff_max)
            except _RetryableLLMError as e:
                last_err = e
                if attempt >= retry_count - 1:
                    raise RuntimeError(
                        f"{e}({failure_ctx}, raw={_format_exc_raw(e)})"
                    ) from e
                _sleep_before_retry(attempt, retry_backoff, retry_backoff_max)
            except RequestException as e:
                # Any other requests error is treated as non-retryable.
                err_resp = getattr(e, "response", None)
                status = getattr(err_resp, "status_code", None)
                body = ""
                if err_resp is not None:
                    try:
                        body = (err_resp.text or "")[:800]
                    except Exception:
                        body = ""
                detail = f"endpoint={endpoint}, model={model_name}, status={status}"
                if body:
                    detail += f", body={body}"
                raise RuntimeError(
                    f"LLM 请求失败({detail}, raw={_format_exc_raw(e)})"
                ) from e
        # Defensive fallback: the last attempt always returns or raises above.
        raw = _format_exc_raw(last_err) if isinstance(last_err, Exception) else str(last_err)
        raise RuntimeError(f"LLM 请求失败(endpoint={endpoint}, model={model_name}, raw={raw})")
def _repair_loose_json_object(s: str) -> str:
"""常见模型输出问题:尾随逗号(, 后紧跟 } 或 ])。"""
return re.sub(r",(\s*[}\]])", r"\1", s)
def _extract_balanced_json_prefix(s: str) -> str:
"""
提取以 `{` 开始的最长“可能完整”的 JSON 对象前缀。
会忽略字符串内的花括号,避免误判。
"""
start = s.find("{")
if start == -1:
return s
in_string = False
escaped = False
depth = 0
end_idx = -1
for i, ch in enumerate(s[start:], start=start):
if in_string:
if escaped:
escaped = False
elif ch == "\\":
escaped = True
elif ch == '"':
in_string = False
continue
if ch == '"':
in_string = True
elif ch == "{":
depth += 1
elif ch == "}":
depth -= 1
if depth == 0:
end_idx = i
break
if end_idx != -1:
return s[start : end_idx + 1]
return s[start:]
def _close_truncated_json_object(s: str) -> str:
"""
处理模型截断导致的 JSON 残缺:
- 若字符串未闭合,补一个 `"`
- 按栈补齐缺失的 `}` / `]`
"""
out: list[str] = []
stack: list[str] = []
in_string = False
escaped = False
for ch in s:
out.append(ch)
if in_string:
if escaped:
escaped = False
elif ch == "\\":
escaped = True
elif ch == '"':
in_string = False
continue
if ch == '"':
in_string = True
continue
if ch == "{":
stack.append("}")
elif ch == "[":
stack.append("]")
elif ch in ("}", "]"):
if stack and stack[-1] == ch:
stack.pop()
if in_string:
out.append('"')
while stack:
out.append(stack.pop())
return "".join(out)
def parse_json_object_from_text(text: str) -> dict[str, Any]:
    """Extract and parse the first ``{ ... }`` JSON object from model output.

    Markdown code fences are removed, then the text from the first ``{`` is tried
    as-is and through progressively more aggressive repairs (balanced-prefix
    extraction, trailing-comma removal, closing a truncated object). The first
    candidate that decodes is returned; later candidates are only built if needed.

    Raises:
        ValueError: no ``{`` in the text, no candidate decodes, or the decoded root
            is not an object.
    """
    s = (text or "").strip()
    # Drop ```json / ``` fences that models often wrap around JSON.
    s = re.sub(r"```(?:json)?", "", s, flags=re.IGNORECASE).strip()
    start = s.find("{")
    if start == -1:
        raise ValueError("未找到 JSON 对象")
    chunk = s[start:]
    balanced_chunk = _extract_balanced_json_prefix(chunk)

    def _candidates():
        yield balanced_chunk
        repaired = _repair_loose_json_object(balanced_chunk)
        yield repaired
        yield _close_truncated_json_object(repaired)
        yield _close_truncated_json_object(_repair_loose_json_object(chunk))
        # Close first, then repair: also removes a trailing comma left at the
        # truncation point (e.g. '[1, 2,' -> '[1, 2,]}' -> '[1, 2]}').
        yield _repair_loose_json_object(_close_truncated_json_object(balanced_chunk))
        yield _repair_loose_json_object(_close_truncated_json_object(chunk))

    decoder = json.JSONDecoder()
    last_err: Optional[Exception] = None
    seen: set[str] = set()
    for candidate in _candidates():
        if candidate in seen:
            continue
        seen.add(candidate)
        try:
            obj, _ = decoder.raw_decode(candidate)
        except json.JSONDecodeError as e:
            last_err = e
            continue
        if not isinstance(obj, dict):
            raise ValueError("JSON 根节点不是对象(dict)")
        return obj
    raise ValueError(f"JSON 解析失败:{last_err}") from last_err
def safe_get_str(v: Any) -> Optional[str]:
    """Return ``str(v)`` with surrounding whitespace stripped, or None if *v* is None or blank."""
    text = "" if v is None else str(v).strip()
    return text or None
# ------------------------------------------------------------------
# Agent multi-turn chat + tool calling (streaming generators)
# ------------------------------------------------------------------
def _iter_chat_stream_events(
    *,
    api_base: str,
    api_key: str,
    model_name: str,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    temperature: float = 0.3,
    max_tokens: int = 4096,
    extra_payload: dict[str, Any] | None = None,
    connect_timeout_sec: int = 20,
    read_timeout_sec: int = 300,
):
    """Stream ``{api_base}/chat/completions`` and yield events as they arrive.

    Events:
        ("delta", str)       -- a text increment of the answer (``delta.content``)
        ("tool_calls", list) -- accumulated tool calls
                                ``[{id, type, function: {name, arguments}}]``, emitted
                                when ``finish_reason == "tool_calls"`` and, for any
                                left over, at the end of the stream
        ("done", dict)       -- ``{"content": <full concatenated text>}``

    Raises (on first iteration, since this is a generator):
        _RetryableLLMError: HTTP 408/429/5xx.
        RuntimeError: any other non-200 HTTP status.
    """
    payload: dict[str, Any] = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": True,
        **(extra_payload or {}),
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"
    content_parts: list[str] = []
    tool_calls_map: dict[int, dict] = {}
    # The context manager closes the streamed response when the stream ends, on error,
    # or when the consumer closes this generator early - releasing the connection.
    with requests.post(
        f"{api_base}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json=payload,
        stream=True,
        timeout=(connect_timeout_sec, max(60, read_timeout_sec)),
    ) as resp:
        if resp.status_code in (408, 429, 500, 502, 503, 504):
            raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}")
        if resp.status_code != 200:
            raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")
        resp.encoding = "utf-8"
        for line in resp.iter_lines(decode_unicode=True):
            if not line:
                continue
            s = line.strip()
            if not s.startswith("data:"):
                continue
            data_str = s[5:].strip()
            if data_str == "[DONE]":
                break
            if not data_str:
                continue  # empty data line is not end-of-stream
            try:
                obj = json.loads(data_str)
            except ValueError:
                continue
            if not isinstance(obj, dict):
                continue
            choices = obj.get("choices")
            if not isinstance(choices, list) or not choices:
                continue
            first = choices[0] if isinstance(choices[0], dict) else {}
            delta = first.get("delta") if isinstance(first.get("delta"), dict) else {}
            # Only the answer text (content) is surfaced; reasoning_content is
            # deliberately ignored so the thinking process is never shown to users.
            content = delta.get("content")
            if isinstance(content, str) and content:
                content_parts.append(content)
                yield ("delta", content)
            # tool_calls arrive as incremental fragments keyed by their index.
            tc_deltas = delta.get("tool_calls")
            if isinstance(tc_deltas, list):
                for tc in tc_deltas:
                    if not isinstance(tc, dict):
                        continue
                    idx = tc.get("index", 0)
                    if not isinstance(idx, int):
                        idx = 0  # keep keys sortable if a gateway sends index=None
                    entry = tool_calls_map.get(idx)
                    if entry is None:
                        entry = tool_calls_map[idx] = {
                            "id": tc.get("id") or "",
                            "type": "function",
                            "function": {"name": "", "arguments": ""},
                        }
                    if tc.get("id"):
                        entry["id"] = tc["id"]
                    fn = tc.get("function") if isinstance(tc.get("function"), dict) else {}
                    name_part = fn.get("name")
                    if isinstance(name_part, str) and name_part:
                        entry["function"]["name"] += name_part
                    args_part = fn.get("arguments")
                    if isinstance(args_part, str) and args_part:
                        entry["function"]["arguments"] += args_part
            if first.get("finish_reason") == "tool_calls" and tool_calls_map:
                yield ("tool_calls", [tool_calls_map[k] for k in sorted(tool_calls_map)])
                tool_calls_map = {}
    # Providers that end the stream without finish_reason="tool_calls" still get their calls delivered.
    if tool_calls_map:
        yield ("tool_calls", [tool_calls_map[k] for k in sorted(tool_calls_map)])
    yield ("done", {"content": "".join(content_parts)})
def _default_disable_thinking_payload(model_name: str) -> dict[str, Any]:
"""Qwen 等推理模型:关闭 thinking仅将最终答案写入 content。"""
if not model_name or "qwen" not in str(model_name).lower():
return {}
return {
"enable_thinking": False,
# vLLM / 部分 OpenAI 兼容网关使用 chat_template_kwargs
"chat_template_kwargs": {"enable_thinking": False},
}
def chat_completions_with_tools(
    *,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    temperature: float = 0.3,
    max_tokens: int = 4096,
    timeout_sec: int = 180,
    extra_payload: dict[str, Any] | None = None,
):
    """Streaming multi-turn chat with tool calling for the agent.

    A generator that yields the event tuples of ``_iter_chat_stream_events`` while
    holding one slot of the module concurrency semaphore. The caller orchestrates
    the tool loop. Being a generator, configuration errors surface on first iteration.
    Caller-supplied *extra_payload* keys override the Qwen thinking-off defaults.
    """
    api_base = (settings.LLM_API_BASE or "").rstrip("/")
    api_key = settings.LLM_API_KEY or ""
    model_name = settings.LLM_MODEL_NAME or ""
    if not (api_base and api_key and model_name):
        raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME")
    payload_overrides: dict[str, Any] = {
        **_default_disable_thinking_payload(model_name),
        **(extra_payload or {}),
    }
    connect_timeout = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20)
    connect_timeout = connect_timeout if connect_timeout > 0 else 20
    read_timeout = int(timeout_sec or 0)
    if read_timeout <= 0:
        read_timeout = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90)
    with _llm_slots:
        yield from _iter_chat_stream_events(
            api_base=api_base,
            api_key=api_key,
            model_name=model_name,
            messages=messages,
            tools=tools,
            temperature=temperature,
            max_tokens=max_tokens,
            extra_payload=payload_overrides or None,
            connect_timeout_sec=connect_timeout,
            read_timeout_sec=read_timeout,
        )