119 lines
3.6 KiB
Python
119 lines
3.6 KiB
Python
"""
|
||
services/llm_client.py
|
||
极简 OpenAI 兼容 Chat Completions 客户端(仅用于生成章节声明,可选)。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import re
|
||
|
||
import requests
|
||
|
||
from config import settings
|
||
|
||
# Module-level logger; used by chat_completions_json to report failed LLM calls.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def llm_configured() -> bool:
    """Tell whether every setting needed to reach the LLM is filled in.

    Checks ``LLM_API_BASE``, ``LLM_API_KEY`` and ``LLM_MODEL_NAME`` on ``settings``
    in that order. Each value is coerced with ``str(value or "")`` and stripped.
    The result is True only if all three are non-empty. Evaluation stops at the
    first empty value.
    """
    required_names = ("LLM_API_BASE", "LLM_API_KEY", "LLM_MODEL_NAME")
    # Generator keeps attribute access lazy, so the short-circuit order is preserved.
    return all(
        str(getattr(settings, name) or "").strip()
        for name in required_names
    )
|
||
|
||
|
||
_THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
|
||
|
||
|
||
def _strip_reasoning(text: str) -> str:
|
||
"""去掉思考模型的思维链:成对 <think>…</think>,以及截断/前导的残留标签。"""
|
||
s = text or ""
|
||
s = _THINK_BLOCK_RE.sub("", s)
|
||
# 仅剩结束标签时,说明前面是未配对的思考段,取最后一个 </think> 之后的正文
|
||
if "</think>" in s:
|
||
s = s.rsplit("</think>", 1)[-1]
|
||
s = re.sub(r"</?think>", "", s, flags=re.IGNORECASE)
|
||
return s.strip()
|
||
|
||
|
||
def chat_completion_text(
    *,
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.2,
    max_tokens: int = 512,
    timeout_sec: int | None = None,
) -> str:
    """Call the OpenAI-compatible ``/chat/completions`` endpoint and return the reply text.

    Settings handling:
    - The base URL, API key and model name are read from ``settings``.
    - Each is stripped of surrounding whitespace, the same normalisation
      ``llm_configured`` applies when validating them. The base URL also loses
      trailing slashes.
    - When ``settings.LLM_DISABLE_THINKING`` is truthy, the payload carries
      ``chat_template_kwargs={"enable_thinking": False}``. This is an extension
      field (e.g. vLLM/Qwen3); servers that don't know it are expected to ignore
      it.

    Any remaining ``<think>`` reasoning in the reply is removed via
    ``_strip_reasoning``. Failures are not caught here; the caller decides
    whether to fall back.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        temperature: Sampling temperature sent to the server.
        max_tokens: Maximum number of tokens to generate.
        timeout_sec: HTTP timeout in seconds. If falsy, falls back to
            ``settings.LLM_HTTP_TIMEOUT_SEC``, and then to 120.

    Returns:
        The assistant message content with reasoning removed and whitespace stripped.

    Raises:
        requests.RequestException: network errors, timeouts, or a non-2xx status
            (via ``raise_for_status``).
        Errors from ``resp.json()``, and KeyError/IndexError/TypeError, when the
            response body is not the expected ``choices[0].message.content`` shape.
    """
    base = str(settings.LLM_API_BASE or "").strip().rstrip("/")
    # Strip key/model too: a trailing newline in the key would otherwise produce an
    # invalid Authorization header even though llm_configured() accepted it.
    api_key = str(settings.LLM_API_KEY or "").strip()
    model = str(settings.LLM_MODEL_NAME or "").strip()

    url = f"{base}/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload: dict = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    # Disable chain-of-thought on thinking models that support this extension field.
    if bool(getattr(settings, "LLM_DISABLE_THINKING", False)):
        payload["chat_template_kwargs"] = {"enable_thinking": False}

    timeout = timeout_sec or int(settings.LLM_HTTP_TIMEOUT_SEC or 120)
    resp = requests.post(
        url,
        headers=headers,
        # Send raw UTF-8 rather than \u-escaped text for non-ASCII prompts.
        data=json.dumps(payload, ensure_ascii=False).encode("utf-8"),
        timeout=timeout,
    )
    resp.raise_for_status()
    data = resp.json()
    content = data["choices"][0]["message"]["content"]
    # _strip_reasoning also strips surrounding whitespace.
    return _strip_reasoning(content or "")
|
||
|
||
|
||
def _extract_json(text: str) -> dict:
|
||
"""从模型输出中解析 JSON object(容忍 ```json``` 代码块包裹)。"""
|
||
s = (text or "").strip()
|
||
if s.startswith("```"):
|
||
s = re.sub(r"^```[a-zA-Z]*\s*", "", s)
|
||
s = re.sub(r"\s*```$", "", s).strip()
|
||
try:
|
||
obj = json.loads(s)
|
||
except json.JSONDecodeError:
|
||
m = re.search(r"\{.*\}", s, flags=re.DOTALL)
|
||
if not m:
|
||
return {}
|
||
try:
|
||
obj = json.loads(m.group(0))
|
||
except json.JSONDecodeError:
|
||
return {}
|
||
return obj if isinstance(obj, dict) else {}
|
||
|
||
|
||
def chat_completions_json(
    *,
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.1,
    max_tokens: int = 4096,
    timeout_sec: int | None = None,
) -> dict:
    """Ask the LLM for a reply and decode it as a JSON object (dict).

    Best-effort wrapper around ``chat_completion_text``. Any exception raised by
    the request is logged as a warning and turned into ``{}``. The reply is then
    handed to ``_extract_json``, which also yields ``{}`` when no object can be
    recovered.
    """
    request_kwargs = {
        "system_prompt": system_prompt,
        "user_prompt": user_prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "timeout_sec": timeout_sec,
    }
    try:
        reply = chat_completion_text(**request_kwargs)
    except Exception as exc:  # noqa: BLE001 - deliberate best-effort boundary
        logger.warning("chat_completions_json 调用失败: %s", exc)
        return {}
    return _extract_json(reply)
|