725 lines
26 KiB
Python
725 lines
26 KiB
Python
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import random
|
||
import re
|
||
import time
|
||
import threading
|
||
from typing import Any, Optional
|
||
|
||
import requests
|
||
from requests import RequestException
|
||
from requests.exceptions import ChunkedEncodingError
|
||
|
||
from config import settings
|
||
|
||
_logger = logging.getLogger(__name__)
|
||
# 生成全过程追踪:完整记录输入 prompt / 调用模型 / 模型输出,写入 logs/generation_trace.log
|
||
_trace_logger = logging.getLogger("generation.trace")
|
||
|
||
_LLM_MAX_CONCURRENCY = 5
|
||
_llm_slots = threading.BoundedSemaphore(_LLM_MAX_CONCURRENCY)
|
||
|
||
|
||
class _RetryableLLMError(RuntimeError):
|
||
"""用于标记可安全重试的 LLM 调用异常。"""
|
||
|
||
|
||
class _ContentFieldStreamExtractor:
|
||
"""从流式 JSON 文本中增量提取 content 字段的已解码正文。"""
|
||
|
||
def __init__(self) -> None:
|
||
self._raw = ""
|
||
self._content_started = False
|
||
self._content_done = False
|
||
self._value_start = -1
|
||
self._consumed_pos = 0
|
||
|
||
def feed(self, chunk: str) -> tuple[str, bool]:
|
||
if not chunk:
|
||
return "", False
|
||
self._raw += chunk
|
||
emitted = ""
|
||
done_now = False
|
||
|
||
if not self._content_started:
|
||
marker = '"content"'
|
||
idx = self._raw.find(marker)
|
||
if idx == -1:
|
||
return "", False
|
||
colon = self._raw.find(":", idx + len(marker))
|
||
if colon == -1:
|
||
return "", False
|
||
quote = self._raw.find('"', colon + 1)
|
||
if quote == -1:
|
||
return "", False
|
||
self._content_started = True
|
||
self._value_start = quote + 1
|
||
self._consumed_pos = self._value_start
|
||
|
||
if self._content_started and not self._content_done:
|
||
emitted, consumed_pos, done_now = self._decode_partial_json_string(
|
||
self._raw,
|
||
self._consumed_pos,
|
||
)
|
||
self._consumed_pos = consumed_pos
|
||
if done_now:
|
||
self._content_done = True
|
||
return emitted, done_now
|
||
|
||
@staticmethod
|
||
def _decode_partial_json_string(src: str, start: int) -> tuple[str, int, bool]:
|
||
out: list[str] = []
|
||
i = start
|
||
n = len(src)
|
||
while i < n:
|
||
ch = src[i]
|
||
if ch == '"':
|
||
if i == 0 or src[i - 1] != "\\" or _ContentFieldStreamExtractor._is_escaped(src, i):
|
||
return "".join(out), i, True
|
||
if ch != "\\":
|
||
out.append(ch)
|
||
i += 1
|
||
continue
|
||
if i + 1 >= n:
|
||
break
|
||
esc = src[i + 1]
|
||
mapping = {
|
||
'"': '"',
|
||
"\\": "\\",
|
||
"/": "/",
|
||
"b": "\b",
|
||
"f": "\f",
|
||
"n": "\n",
|
||
"r": "\r",
|
||
"t": "\t",
|
||
}
|
||
if esc == "u":
|
||
if i + 6 > n:
|
||
break
|
||
hex_part = src[i + 2 : i + 6]
|
||
try:
|
||
out.append(chr(int(hex_part, 16)))
|
||
except Exception:
|
||
pass
|
||
i += 6
|
||
continue
|
||
if esc in mapping:
|
||
out.append(mapping[esc])
|
||
i += 2
|
||
continue
|
||
out.append(esc)
|
||
i += 2
|
||
return "".join(out), i, False
|
||
|
||
@staticmethod
|
||
def _is_escaped(src: str, quote_index: int) -> bool:
|
||
backslashes = 0
|
||
i = quote_index - 1
|
||
while i >= 0 and src[i] == "\\":
|
||
backslashes += 1
|
||
i -= 1
|
||
return backslashes % 2 == 0
|
||
|
||
|
||
def _format_exc_raw(e: Exception) -> str:
|
||
"""统一输出最直接的异常原文(类型 + repr)。"""
|
||
return f"{type(e).__name__}: {e!r}"
|
||
|
||
|
||
def _chat_completions_stream_text(
    *,
    api_base: str,
    api_key: str,
    model_name: str,
    system_prompt: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
    extra_payload: dict[str, Any],
    connect_timeout_sec: int,
    read_timeout_sec: int = 300,
    on_content_delta: Optional[callable] = None,
) -> str:
    """Stream an OpenAI-compatible chat completion over SSE and return the full text.

    - The connect timeout is kept separate so a stuck handshake fails fast.
    - The read timeout (at least 60s) bounds each wait for stream data, so reading the
      stream cannot hang forever.

    ``on_content_delta``, when given, receives ``("delta", text)`` with newly decoded text
    of the JSON ``content`` field, and ``("finalizing", "")`` once that field's closing
    quote has streamed. Callback errors are logged at DEBUG and otherwise ignored.

    The HTTP response is always closed, so its connection is released even on error paths
    or when the stream ends early on ``[DONE]``.

    Raises:
        _RetryableLLMError: retryable HTTP status (408/429/5xx), empty output, or a stream
            cut off (ChunkedEncodingError) before a parseable JSON object was received.
        RuntimeError: any other non-200 HTTP status.
    """
    _logger.info(
        "LLM 流式调用开始 | model=%s | temperature=%s | max_tokens=%s | timeout_connect=%s timeout_read=%s",
        model_name, temperature, max_tokens, connect_timeout_sec, read_timeout_sec,
    )
    resp = requests.post(
        f"{api_base}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "response_format": {"type": "json_object"},
            "stream": True,
            **extra_payload,
        },
        stream=True,
        timeout=(connect_timeout_sec, max(60, read_timeout_sec)),
    )
    try:
        if resp.status_code in (408, 429, 500, 502, 503, 504):
            raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}")
        if resp.status_code != 200:
            raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")

        resp.encoding = "utf-8"
        chunks: list[str] = []
        extractor = _ContentFieldStreamExtractor()

        def _notify(kind: str, text: str) -> None:
            """Invoke the callback best-effort: a failing UI hook must not abort generation."""
            try:
                on_content_delta(kind, text)
            except Exception:
                _logger.debug("on_content_delta callback failed", exc_info=True)

        def _accept(text: str) -> None:
            """Record a piece of model output and forward the decoded content delta."""
            chunks.append(text)
            if not on_content_delta:
                return
            delta_text, done_now = extractor.feed(text)
            if delta_text:
                _notify("delta", delta_text)
            if done_now:
                _notify("finalizing", "")

        try:
            for line in resp.iter_lines(decode_unicode=True):
                if not line:
                    continue
                s = line.strip()
                if not s.startswith("data:"):
                    continue
                payload = s[5:].strip()
                if payload == "[DONE]":
                    break
                if not payload:
                    continue  # an empty data frame is not end-of-stream
                try:
                    obj = json.loads(payload)
                except ValueError:
                    continue
                if not isinstance(obj, dict):
                    continue
                choices = obj.get("choices")
                if not isinstance(choices, list) or not choices:
                    continue
                first = choices[0] if isinstance(choices[0], dict) else {}
                # Streaming deltas normally arrive in "delta"; some implementations put the
                # final result in "message" instead, so both are accepted (in that order).
                for container_key in ("delta", "message"):
                    container = first.get(container_key)
                    if not isinstance(container, dict):
                        continue
                    text = container.get("content")
                    if isinstance(text, str) and text:
                        _accept(text)
        except ChunkedEncodingError as e:
            partial_text = "".join(chunks).strip()
            # If the stream broke but a complete JSON object already arrived, use it rather
            # than paying for a retry.
            if partial_text:
                try:
                    parse_json_object_from_text(partial_text)
                except Exception:
                    pass
                else:
                    return partial_text
            raise _RetryableLLMError(f"LLM 流中断: {_format_exc_raw(e)}") from e

        text = "".join(chunks).strip()
        if not text:
            raise _RetryableLLMError("LLM 返回空内容")
        return text
    finally:
        resp.close()
|
||
|
||
|
||
def _sleep_before_retry(attempt: int, backoff_sec: float, backoff_max_sec: float) -> None:
    """Sleep before the next attempt: capped exponential backoff plus a small jitter."""
    sleep_sec = min(backoff_sec * (2 ** attempt), backoff_max_sec)
    # Jitter (at most 0.5s and at most 20% of the delay) de-synchronises concurrent retries.
    sleep_sec += random.uniform(0, min(0.5, sleep_sec * 0.2))
    time.sleep(sleep_sec)


def _chat_completions_plain_text(
    *,
    api_base: str,
    api_key: str,
    model_name: str,
    system_prompt: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
    extra_payload: dict[str, Any],
    connect_timeout_sec: int,
    read_timeout_sec: int,
) -> str:
    """Non-streaming request: POST once and return ``choices[0].message.content``.

    Raises _RetryableLLMError on retryable HTTP statuses or empty content, RuntimeError on
    other non-200 statuses.
    """
    # Separate connect and read timeouts: a long generation should only consume "read" time,
    # not share a single limit with the connection handshake and time out too early.
    read_timeout = max(int(connect_timeout_sec) + 5, int(read_timeout_sec))
    resp = requests.post(
        f"{api_base}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "response_format": {"type": "json_object"},
            **extra_payload,
        },
        timeout=(connect_timeout_sec, read_timeout),
    )
    if resp.status_code in (408, 429, 500, 502, 503, 504):
        raise _RetryableLLMError(
            f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}"
        )
    if resp.status_code != 200:
        raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")

    data = resp.json()
    choices = data.get("choices") if isinstance(data, dict) else None
    first = choices[0] if isinstance(choices, list) and choices and isinstance(choices[0], dict) else {}
    # Tolerate "message": null (the old chained .get() raised AttributeError here).
    message = first.get("message") if isinstance(first.get("message"), dict) else {}
    content = message.get("content", "")
    if not isinstance(content, str) or not content.strip():
        raise _RetryableLLMError("LLM 返回空内容")
    return content


def chat_completions_json(
    *,
    system_prompt: str,
    user_prompt: str,
    temperature: float = 0.2,
    max_tokens: int = 4096,
    timeout_sec: int = 180,
    on_content_delta: Optional[callable] = None,
    log_context: str = "",
) -> dict[str, Any]:
    """Call the OpenAI-compatible chat/completions endpoint and return its reply as a JSON object.

    Uses the project settings LLM_API_BASE / LLM_API_KEY / LLM_MODEL_NAME and requests
    ``response_format={"type": "json_object"}``; the reply is parsed with
    ``parse_json_object_from_text``.

    Args:
        system_prompt, user_prompt: message contents.
        temperature, max_tokens: forwarded to the API.
        timeout_sec: read timeout in seconds; a falsy or non-positive value falls back to
            settings.LLM_HTTP_TIMEOUT_SEC (default 90).
        on_content_delta: optional streaming callback, forwarded to the stream reader.
        log_context: caller label (e.g. a chapter id) used to tell calls apart in the
            generation trace log.

    Makes up to settings.LLM_RETRY_COUNT attempts (default 2, minimum 1). It retries on
    timeouts, connection errors, retryable HTTP statuses, empty output and unparseable
    JSON, with capped exponential backoff. A ``_llm_slots`` concurrency slot is held only
    while a request is in flight, never during a backoff sleep.

    Raises:
        RuntimeError: LLM not configured, a non-retryable HTTP/request error, or every
            attempt failed.
    """
    api_base = (settings.LLM_API_BASE or "").rstrip("/")
    api_key = settings.LLM_API_KEY or ""
    model_name = settings.LLM_MODEL_NAME or ""

    if not api_base or not api_key or not model_name:
        raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME")

    ctx = log_context or "-"
    _trace_logger.info(
        "[输入] context=%s | model=%s | temperature=%s | max_tokens=%s\n"
        "----- SYSTEM PROMPT -----\n%s\n"
        "----- USER PROMPT -----\n%s\n"
        "----- END INPUT -----",
        ctx, model_name, temperature, max_tokens, system_prompt, user_prompt,
    )

    extra_payload: dict[str, Any] = {}
    # Some SiliconFlow Qwen models write their answer to reasoning_content by default, leaving
    # content empty; disable thinking explicitly so the output lands in content and parses.
    if "siliconflow" in api_base.lower() and "qwen" in model_name.lower():
        extra_payload["enable_thinking"] = False

    final_timeout_sec = int(timeout_sec or 0)
    if final_timeout_sec <= 0:
        final_timeout_sec = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90)
    retry_count = max(1, int(getattr(settings, "LLM_RETRY_COUNT", 2) or 2))
    retry_backoff = float(getattr(settings, "LLM_RETRY_BACKOFF_SEC", 1.0) or 1.0)
    retry_backoff_max = float(getattr(settings, "LLM_RETRY_BACKOFF_MAX_SEC", 12.0) or 12.0)
    connect_timeout_sec = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20)
    if connect_timeout_sec <= 0:
        connect_timeout_sec = 20
    # Hard-coded switch; the non-streaming path is kept as an alternative implementation.
    use_stream = True

    _logger.info(
        "chat_completions_json 调用 | model=%s | temperature=%s | max_tokens=%s | timeout=%s | retry=%s",
        model_name, temperature, max_tokens, final_timeout_sec, retry_count,
    )

    request_kwargs: dict[str, Any] = {
        "api_base": api_base,
        "api_key": api_key,
        "model_name": model_name,
        "system_prompt": system_prompt,
        "user_prompt": user_prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "extra_payload": extra_payload,
        "connect_timeout_sec": connect_timeout_sec,
        "read_timeout_sec": final_timeout_sec,
    }

    last_err: Optional[Exception] = None
    for attempt in range(retry_count):
        try:
            # The slot covers only the in-flight request, so a failing call that is about to
            # back off does not block other callers while it sleeps.
            with _llm_slots:
                if use_stream:
                    content = _chat_completions_stream_text(
                        **request_kwargs,
                        on_content_delta=on_content_delta,
                    )
                else:
                    content = _chat_completions_plain_text(**request_kwargs)
            try:
                obj = parse_json_object_from_text(content)
            except ValueError as e:
                raise _RetryableLLMError(f"LLM JSON 解析失败: {e}") from e
            _logger.info(
                "chat_completions_json 成功 | model=%s | attempt=%d/%d | content_len=%d | keys=%s",
                model_name, attempt + 1, retry_count, len(content), list(obj.keys())[:8],
            )
            _trace_logger.info(
                "[输出] context=%s | model=%s | attempt=%d/%d | output_len=%d\n"
                "----- MODEL OUTPUT -----\n%s\n"
                "----- END OUTPUT -----",
                ctx, model_name, attempt + 1, retry_count, len(content), content,
            )
            return obj
        except (
            requests.ReadTimeout,
            requests.ConnectTimeout,
            requests.ConnectionError,
            ChunkedEncodingError,
        ) as e:
            last_err = e
            if attempt >= retry_count - 1:
                raise RuntimeError(
                    "LLM 请求超时/连接失败"
                    f"(已重试{retry_count}次,timeout={final_timeout_sec}s)"
                    f",endpoint={api_base}/chat/completions"
                    f",model={model_name}"
                    f",raw={_format_exc_raw(e)}"
                ) from e
            _sleep_before_retry(attempt, retry_backoff, retry_backoff_max)
        except _RetryableLLMError as e:
            last_err = e
            if attempt >= retry_count - 1:
                raise RuntimeError(
                    f"{e}(已重试{retry_count}次,timeout={final_timeout_sec}s)"
                    f",endpoint={api_base}/chat/completions"
                    f",model={model_name}"
                    f",raw={_format_exc_raw(e)}"
                ) from e
            _sleep_before_retry(attempt, retry_backoff, retry_backoff_max)
        except RequestException as e:
            # Any other requests error is treated as non-retryable.
            resp = getattr(e, "response", None)
            status = getattr(resp, "status_code", None)
            body = ""
            if resp is not None:
                try:
                    body = (resp.text or "")[:800]
                except Exception:
                    body = ""
            raise RuntimeError(
                "LLM 请求失败"
                f",endpoint={api_base}/chat/completions"
                f",model={model_name}"
                f",status={status}"
                + (f",body={body}" if body else "")
                + f",raw={_format_exc_raw(e)}"
            ) from e

    # Defensive: the last attempt always returns or raises above.
    raw = _format_exc_raw(last_err) if isinstance(last_err, Exception) else str(last_err)
    raise RuntimeError(
        "LLM 请求失败"
        f",endpoint={api_base}/chat/completions"
        f",model={model_name}"
        f",raw={raw}"
    )
|
||
|
||
|
||
def _repair_loose_json_object(s: str) -> str:
|
||
"""常见模型输出问题:尾随逗号(, 后紧跟 } 或 ])。"""
|
||
return re.sub(r",(\s*[}\]])", r"\1", s)
|
||
|
||
|
||
def _extract_balanced_json_prefix(s: str) -> str:
|
||
"""
|
||
提取以 `{` 开始的最长“可能完整”的 JSON 对象前缀。
|
||
会忽略字符串内的花括号,避免误判。
|
||
"""
|
||
start = s.find("{")
|
||
if start == -1:
|
||
return s
|
||
in_string = False
|
||
escaped = False
|
||
depth = 0
|
||
end_idx = -1
|
||
for i, ch in enumerate(s[start:], start=start):
|
||
if in_string:
|
||
if escaped:
|
||
escaped = False
|
||
elif ch == "\\":
|
||
escaped = True
|
||
elif ch == '"':
|
||
in_string = False
|
||
continue
|
||
if ch == '"':
|
||
in_string = True
|
||
elif ch == "{":
|
||
depth += 1
|
||
elif ch == "}":
|
||
depth -= 1
|
||
if depth == 0:
|
||
end_idx = i
|
||
break
|
||
if end_idx != -1:
|
||
return s[start : end_idx + 1]
|
||
return s[start:]
|
||
|
||
|
||
def _close_truncated_json_object(s: str) -> str:
|
||
"""
|
||
处理模型截断导致的 JSON 残缺:
|
||
- 若字符串未闭合,补一个 `"`
|
||
- 按栈补齐缺失的 `}` / `]`
|
||
"""
|
||
out: list[str] = []
|
||
stack: list[str] = []
|
||
in_string = False
|
||
escaped = False
|
||
|
||
for ch in s:
|
||
out.append(ch)
|
||
if in_string:
|
||
if escaped:
|
||
escaped = False
|
||
elif ch == "\\":
|
||
escaped = True
|
||
elif ch == '"':
|
||
in_string = False
|
||
continue
|
||
if ch == '"':
|
||
in_string = True
|
||
continue
|
||
if ch == "{":
|
||
stack.append("}")
|
||
elif ch == "[":
|
||
stack.append("]")
|
||
elif ch in ("}", "]"):
|
||
if stack and stack[-1] == ch:
|
||
stack.pop()
|
||
|
||
if in_string:
|
||
out.append('"')
|
||
while stack:
|
||
out.append(stack.pop())
|
||
return "".join(out)
|
||
|
||
|
||
def parse_json_object_from_text(text: str) -> dict[str, Any]:
    """Extract and parse the first ``{ ... }`` JSON object from model output.

    The text is first tried as-is, so string values that legitimately contain Markdown
    code fences (```...```) survive intact. Only if that fails is a copy with all fences
    removed tried, which was the previous sole behaviour. For each source the candidates
    are attempted in order:

    1. the balanced ``{...}`` prefix;
    2. that prefix with trailing commas removed;
    3. (2) auto-closed;
    4. the whole remainder from the first ``{``, comma-repaired and auto-closed (for
       truncated output).

    Raises:
        ValueError: no ``{`` in the text, or no candidate decodes.
    """
    raw = (text or "").strip()
    defenced = re.sub(r"```(?:json)?", "", raw, flags=re.IGNORECASE).replace("```", "").strip()
    if defenced.find("{") == -1:
        raise ValueError("未找到 JSON 对象")

    sources = (raw, defenced) if raw != defenced else (defenced,)
    decoder = json.JSONDecoder()
    last_err: Optional[Exception] = None
    for source in sources:
        chunk = source[source.find("{"):]
        balanced_chunk = _extract_balanced_json_prefix(chunk)
        for candidate in (
            balanced_chunk,
            _repair_loose_json_object(balanced_chunk),
            _close_truncated_json_object(_repair_loose_json_object(balanced_chunk)),
            _close_truncated_json_object(_repair_loose_json_object(chunk)),
        ):
            try:
                obj, _ = decoder.raw_decode(candidate)
            except json.JSONDecodeError as e:
                last_err = e
                continue
            if not isinstance(obj, dict):
                raise ValueError("JSON 根节点不是对象(dict)")
            return obj
    raise ValueError(f"JSON 解析失败:{last_err}") from last_err
|
||
|
||
|
||
def safe_get_str(v: Any) -> Optional[str]:
    """Return ``str(v)`` with surrounding whitespace stripped, or None if v is None or that string is empty."""
    if v is None:
        return None
    text = str(v).strip()
    return text or None
|
||
|
||
|
||
# ------------------------------------------------------------------
|
||
# Agent 多轮对话 + Tool Calling(流式生成器)
|
||
# ------------------------------------------------------------------
|
||
|
||
def _iter_chat_stream_events(
    *,
    api_base: str,
    api_key: str,
    model_name: str,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    temperature: float = 0.3,
    max_tokens: int = 4096,
    extra_payload: dict[str, Any] | None = None,
    connect_timeout_sec: int = 20,
    read_timeout_sec: int = 300,
):
    """Stream an OpenAI-compatible /chat/completions call and yield events.

    Yields:
        ("delta", str): a piece of answer text (``reasoning_content`` is never forwarded).
        ("tool_calls", list): assembled tool calls ``[{id, type, function: {name, arguments}}]``,
            emitted when ``finish_reason == "tool_calls"`` and again for any leftovers at the end.
        ("done", dict): ``{"content": <all answer text concatenated>}``.

    The HTTP response is closed in a ``finally`` block, so the connection is released even
    if the consumer stops iterating early.

    Raises:
        _RetryableLLMError: retryable HTTP status (408/429/5xx).
        RuntimeError: any other non-200 HTTP status.
    """
    payload: dict[str, Any] = {
        "model": model_name,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": True,
        **(extra_payload or {}),
    }
    if tools:
        payload["tools"] = tools
        payload["tool_choice"] = "auto"

    resp = requests.post(
        f"{api_base}/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json=payload,
        stream=True,
        timeout=(connect_timeout_sec, max(60, read_timeout_sec)),
    )
    try:
        if resp.status_code in (408, 429, 500, 502, 503, 504):
            raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}")
        if resp.status_code != 200:
            raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}")

        resp.encoding = "utf-8"
        content_parts: list[str] = []
        tool_calls_map: dict[int, dict[str, Any]] = {}

        for line in resp.iter_lines(decode_unicode=True):
            if not line:
                continue
            s = line.strip()
            if not s.startswith("data:"):
                continue
            data_str = s[5:].strip()
            if data_str == "[DONE]":
                break
            if not data_str:
                continue  # an empty data frame is not end-of-stream
            try:
                obj = json.loads(data_str)
            except ValueError:
                continue
            if not isinstance(obj, dict):
                continue
            choices = obj.get("choices")
            if not isinstance(choices, list) or not choices:
                continue
            first = choices[0] if isinstance(choices[0], dict) else {}
            delta = first.get("delta") if isinstance(first.get("delta"), dict) else {}

            # Only the answer text in "content" is forwarded; reasoning_content is ignored so
            # the model's thinking is never shown to the user.
            content = delta.get("content")
            if isinstance(content, str) and content:
                content_parts.append(content)
                yield ("delta", content)

            # Tool calls are streamed incrementally, keyed by their "index".
            tc_deltas = delta.get("tool_calls")
            if isinstance(tc_deltas, list):
                for tc in tc_deltas:
                    if not isinstance(tc, dict):
                        continue
                    idx = tc.get("index")
                    if not isinstance(idx, int):
                        idx = 0  # missing/null index: treat as the first call
                    entry = tool_calls_map.setdefault(
                        idx,
                        {"id": "", "type": "function", "function": {"name": "", "arguments": ""}},
                    )
                    if tc.get("id"):
                        entry["id"] = tc["id"]
                    fn = tc.get("function") if isinstance(tc.get("function"), dict) else {}
                    if fn.get("name"):
                        entry["function"]["name"] += fn["name"]
                    if fn.get("arguments"):
                        entry["function"]["arguments"] += fn["arguments"]

            if first.get("finish_reason") == "tool_calls" and tool_calls_map:
                yield ("tool_calls", [tool_calls_map[k] for k in sorted(tool_calls_map)])
                tool_calls_map = {}

        if tool_calls_map:
            yield ("tool_calls", [tool_calls_map[k] for k in sorted(tool_calls_map)])

        yield ("done", {"content": "".join(content_parts)})
    finally:
        resp.close()
|
||
|
||
|
||
def _default_disable_thinking_payload(model_name: str) -> dict[str, Any]:
|
||
"""Qwen 等推理模型:关闭 thinking,仅将最终答案写入 content。"""
|
||
if not model_name or "qwen" not in str(model_name).lower():
|
||
return {}
|
||
return {
|
||
"enable_thinking": False,
|
||
# vLLM / 部分 OpenAI 兼容网关使用 chat_template_kwargs
|
||
"chat_template_kwargs": {"enable_thinking": False},
|
||
}
|
||
|
||
|
||
def chat_completions_with_tools(
    *,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    temperature: float = 0.3,
    max_tokens: int = 4096,
    timeout_sec: int = 180,
    extra_payload: dict[str, Any] | None = None,
):
    """Multi-turn agent chat with tool calling, exposed as a generator of event tuples.

    Yields whatever ``_iter_chat_stream_events`` yields. The caller orchestrates the tool
    loop. Because this is a generator, the configuration check (RuntimeError when
    LLM_API_BASE / LLM_API_KEY / LLM_MODEL_NAME is missing) runs on the first ``next()``.
    A ``_llm_slots`` concurrency slot is held for as long as the stream is being iterated.
    """
    base_url = (settings.LLM_API_BASE or "").rstrip("/")
    token = settings.LLM_API_KEY or ""
    model = settings.LLM_MODEL_NAME or ""

    if not (base_url and token and model):
        raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME")

    # Caller-supplied fields override the default "thinking off" fields.
    payload_overrides: dict[str, Any] = {
        **_default_disable_thinking_payload(model),
        **(extra_payload or {}),
    }

    connect_timeout = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20)
    connect_timeout = connect_timeout if connect_timeout > 0 else 20

    read_timeout = int(timeout_sec or 0)
    if read_timeout <= 0:
        read_timeout = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90)

    with _llm_slots:
        yield from _iter_chat_stream_events(
            api_base=base_url,
            api_key=token,
            model_name=model,
            messages=messages,
            tools=tools,
            temperature=temperature,
            max_tokens=max_tokens,
            extra_payload=payload_overrides or None,
            connect_timeout_sec=connect_timeout,
            read_timeout_sec=read_timeout,
        )
|