from __future__ import annotations import json import logging import random import re import time import threading from typing import Any, Optional import requests from requests import RequestException from requests.exceptions import ChunkedEncodingError from config import settings _logger = logging.getLogger(__name__) # 生成全过程追踪:完整记录输入 prompt / 调用模型 / 模型输出,写入 logs/generation_trace.log _trace_logger = logging.getLogger("generation.trace") _LLM_MAX_CONCURRENCY = 5 _llm_slots = threading.BoundedSemaphore(_LLM_MAX_CONCURRENCY) class _RetryableLLMError(RuntimeError): """用于标记可安全重试的 LLM 调用异常。""" class _ContentFieldStreamExtractor: """从流式 JSON 文本中增量提取 content 字段的已解码正文。""" def __init__(self) -> None: self._raw = "" self._content_started = False self._content_done = False self._value_start = -1 self._consumed_pos = 0 def feed(self, chunk: str) -> tuple[str, bool]: if not chunk: return "", False self._raw += chunk emitted = "" done_now = False if not self._content_started: marker = '"content"' idx = self._raw.find(marker) if idx == -1: return "", False colon = self._raw.find(":", idx + len(marker)) if colon == -1: return "", False quote = self._raw.find('"', colon + 1) if quote == -1: return "", False self._content_started = True self._value_start = quote + 1 self._consumed_pos = self._value_start if self._content_started and not self._content_done: emitted, consumed_pos, done_now = self._decode_partial_json_string( self._raw, self._consumed_pos, ) self._consumed_pos = consumed_pos if done_now: self._content_done = True return emitted, done_now @staticmethod def _decode_partial_json_string(src: str, start: int) -> tuple[str, int, bool]: out: list[str] = [] i = start n = len(src) while i < n: ch = src[i] if ch == '"': if i == 0 or src[i - 1] != "\\" or _ContentFieldStreamExtractor._is_escaped(src, i): return "".join(out), i, True if ch != "\\": out.append(ch) i += 1 continue if i + 1 >= n: break esc = src[i + 1] mapping = { '"': '"', "\\": "\\", "/": "/", "b": "\b", "f": "\f", "n": "\n", "r": "\r", "t": "\t", } if esc == "u": if i + 6 > n: break hex_part = src[i + 2 : i + 6] try: out.append(chr(int(hex_part, 16))) except Exception: pass i += 6 continue if esc in mapping: out.append(mapping[esc]) i += 2 continue out.append(esc) i += 2 return "".join(out), i, False @staticmethod def _is_escaped(src: str, quote_index: int) -> bool: backslashes = 0 i = quote_index - 1 while i >= 0 and src[i] == "\\": backslashes += 1 i -= 1 return backslashes % 2 == 0 def _format_exc_raw(e: Exception) -> str: """统一输出最直接的异常原文(类型 + repr)。""" return f"{type(e).__name__}: {e!r}" def _chat_completions_stream_text( *, api_base: str, api_key: str, model_name: str, system_prompt: str, user_prompt: str, temperature: float, max_tokens: int, extra_payload: dict[str, Any], connect_timeout_sec: int, read_timeout_sec: int = 300, on_content_delta: Optional[callable] = None, ) -> str: """ 以 OpenAI-compat SSE 流式读取模型输出文本。 - connect timeout 保留,避免连接阶段长时间卡死 - read timeout 防止流式读取无限挂起(默认 300s) """ _logger.info( "LLM 流式调用开始 | model=%s | temperature=%s | max_tokens=%s | timeout_connect=%s timeout_read=%s", model_name, temperature, max_tokens, connect_timeout_sec, read_timeout_sec, ) resp = requests.post( f"{api_base}/chat/completions", headers={ "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", }, json={ "model": model_name, "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], "temperature": temperature, "max_tokens": max_tokens, "response_format": {"type": "json_object"}, "stream": True, **extra_payload, }, stream=True, timeout=(connect_timeout_sec, max(60, read_timeout_sec)), ) if resp.status_code in (408, 429, 500, 502, 503, 504): raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}") if resp.status_code != 200: raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}") resp.encoding = "utf-8" chunks: list[str] = [] extractor = _ContentFieldStreamExtractor() try: for line in resp.iter_lines(decode_unicode=True): if not line: continue s = line.strip() if not s.startswith("data:"): continue payload = s[5:].strip() if not payload or payload == "[DONE]": break try: obj = json.loads(payload) except Exception: continue choices = obj.get("choices") if not isinstance(choices, list) or not choices: continue first = choices[0] if isinstance(choices[0], dict) else {} delta = first.get("delta") if isinstance(first.get("delta"), dict) else {} content = delta.get("content") if isinstance(content, str) and content: chunks.append(content) # print(content, end="", flush=True) if on_content_delta: delta_text, done_now = extractor.feed(content) if delta_text: try: on_content_delta("delta", delta_text) except Exception: pass if done_now: try: on_content_delta("finalizing", "") except Exception: pass # 兼容部分实现把最终结果放在 message.content message = first.get("message") if isinstance(first.get("message"), dict) else {} msg_content = message.get("content") if isinstance(msg_content, str) and msg_content: chunks.append(msg_content) # print(msg_content, end="", flush=True) if on_content_delta: delta_text, done_now = extractor.feed(msg_content) if delta_text: try: on_content_delta("delta", delta_text) except Exception: pass if done_now: try: on_content_delta("finalizing", "") except Exception: pass except ChunkedEncodingError as e: partial_text = "".join(chunks).strip() # 若流提前结束但已收到完整 JSON,直接使用,避免无谓重试失败。 if partial_text: try: parse_json_object_from_text(partial_text) return partial_text except Exception: pass raise _RetryableLLMError(f"LLM 流中断: {_format_exc_raw(e)}") from e text = "".join(chunks).strip() if not text: raise _RetryableLLMError("LLM 返回空内容") print() return text def chat_completions_json( *, system_prompt: str, user_prompt: str, temperature: float = 0.2, max_tokens: int = 4096, timeout_sec: int = 180, on_content_delta: Optional[callable] = None, log_context: str = "", ) -> dict[str, Any]: """ 统一的 OpenAI-compat chat/completions 调用,强制返回 JSON object。 复用项目现有配置:LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME。 log_context: 调用来源标签(如章节编号),用于在 generation_trace.log 中区分各次生成调用。 """ api_base = (settings.LLM_API_BASE or "").rstrip("/") api_key = settings.LLM_API_KEY or "" model_name = settings.LLM_MODEL_NAME or "" if not api_base or not api_key or not model_name: raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME") ctx = log_context or "-" _trace_logger.info( "[输入] context=%s | model=%s | temperature=%s | max_tokens=%s\n" "----- SYSTEM PROMPT -----\n%s\n" "----- USER PROMPT -----\n%s\n" "----- END INPUT -----", ctx, model_name, temperature, max_tokens, system_prompt, user_prompt, ) extra_payload: dict[str, Any] = {} # SiliconFlow 的部分 Qwen 模型默认把输出写到 reasoning_content,导致 content 为空; # 显式关闭 thinking,确保最终输出进入 content,避免下游解析失败。 if "siliconflow" in api_base.lower() and "qwen" in model_name.lower(): extra_payload["enable_thinking"] = False final_timeout_sec = int(timeout_sec or 0) if final_timeout_sec <= 0: final_timeout_sec = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90) retry_count = int(getattr(settings, "LLM_RETRY_COUNT", 2) or 2) if retry_count < 1: retry_count = 1 retry_backoff = float(getattr(settings, "LLM_RETRY_BACKOFF_SEC", 1.0) or 1.0) retry_backoff_max = float(getattr(settings, "LLM_RETRY_BACKOFF_MAX_SEC", 12.0) or 12.0) connect_timeout_sec = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20) if connect_timeout_sec <= 0: connect_timeout_sec = 20 use_stream = True _logger.info( "chat_completions_json 调用 | model=%s | temperature=%s | max_tokens=%s | timeout=%s | retry=%s", model_name, temperature, max_tokens, final_timeout_sec, retry_count, ) with _llm_slots: last_err: Optional[Exception] = None for attempt in range(retry_count): try: if use_stream: content = _chat_completions_stream_text( api_base=api_base, api_key=api_key, model_name=model_name, system_prompt=system_prompt, user_prompt=user_prompt, temperature=temperature, max_tokens=max_tokens, extra_payload=extra_payload, connect_timeout_sec=connect_timeout_sec, read_timeout_sec=final_timeout_sec, on_content_delta=on_content_delta, ) else: # 分离连接超时与读超时:长生成阶段只应占用「读」时间,避免与连接握手混在一个上限里过早超时 read_timeout = max(int(connect_timeout_sec) + 5, int(final_timeout_sec)) resp = requests.post( f"{api_base}/chat/completions", headers={ "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", }, json={ "model": model_name, "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], "temperature": temperature, "max_tokens": max_tokens, "response_format": {"type": "json_object"}, **extra_payload, }, timeout=(connect_timeout_sec, read_timeout), ) if resp.status_code in (408, 429, 500, 502, 503, 504): raise _RetryableLLMError( f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}" ) if resp.status_code != 200: raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}") data = resp.json() content = ( (data.get("choices") or [{}])[0] .get("message", {}) .get("content", "") ) if not isinstance(content, str) or not content.strip(): raise _RetryableLLMError("LLM 返回空内容") try: obj = parse_json_object_from_text(content) _logger.info( "chat_completions_json 成功 | model=%s | attempt=%d/%d | content_len=%d | keys=%s", model_name, attempt + 1, retry_count, len(content), list(obj.keys())[:8], ) _trace_logger.info( "[输出] context=%s | model=%s | attempt=%d/%d | output_len=%d\n" "----- MODEL OUTPUT -----\n%s\n" "----- END OUTPUT -----", ctx, model_name, attempt + 1, retry_count, len(content), content, ) return obj except ValueError as e: raise _RetryableLLMError(f"LLM JSON 解析失败: {e}") from e except ( requests.ReadTimeout, requests.ConnectTimeout, requests.ConnectionError, ChunkedEncodingError, ) as e: last_err = e if attempt >= retry_count - 1: raise RuntimeError( "LLM 请求超时/连接失败" f"(已重试{retry_count}次,timeout={final_timeout_sec}s)" f",endpoint={api_base}/chat/completions" f",model={model_name}" f",raw={_format_exc_raw(e)}" ) from e sleep_sec = min(retry_backoff * (2 ** attempt), retry_backoff_max) sleep_sec += random.uniform(0, min(0.5, sleep_sec * 0.2)) time.sleep(sleep_sec) except _RetryableLLMError as e: last_err = e if attempt >= retry_count - 1: raise RuntimeError( f"{e}(已重试{retry_count}次,timeout={final_timeout_sec}s)" f",endpoint={api_base}/chat/completions" f",model={model_name}" f",raw={_format_exc_raw(e)}" ) from e sleep_sec = min(retry_backoff * (2 ** attempt), retry_backoff_max) sleep_sec += random.uniform(0, min(0.5, sleep_sec * 0.2)) time.sleep(sleep_sec) except RequestException as e: resp = getattr(e, "response", None) status = getattr(resp, "status_code", None) body = "" if resp is not None: try: body = (resp.text or "")[:800] except Exception: body = "" raise RuntimeError( "LLM 请求失败" f",endpoint={api_base}/chat/completions" f",model={model_name}" f",status={status}" + (f",body={body}" if body else "") + f",raw={_format_exc_raw(e)}" ) from e else: raw = _format_exc_raw(last_err) if isinstance(last_err, Exception) else str(last_err) raise RuntimeError( "LLM 请求失败" f",endpoint={api_base}/chat/completions" f",model={model_name}" f",raw={raw}" ) def _repair_loose_json_object(s: str) -> str: """常见模型输出问题:尾随逗号(, 后紧跟 } 或 ])。""" return re.sub(r",(\s*[}\]])", r"\1", s) def _extract_balanced_json_prefix(s: str) -> str: """ 提取以 `{` 开始的最长“可能完整”的 JSON 对象前缀。 会忽略字符串内的花括号,避免误判。 """ start = s.find("{") if start == -1: return s in_string = False escaped = False depth = 0 end_idx = -1 for i, ch in enumerate(s[start:], start=start): if in_string: if escaped: escaped = False elif ch == "\\": escaped = True elif ch == '"': in_string = False continue if ch == '"': in_string = True elif ch == "{": depth += 1 elif ch == "}": depth -= 1 if depth == 0: end_idx = i break if end_idx != -1: return s[start : end_idx + 1] return s[start:] def _close_truncated_json_object(s: str) -> str: """ 处理模型截断导致的 JSON 残缺: - 若字符串未闭合,补一个 `"` - 按栈补齐缺失的 `}` / `]` """ out: list[str] = [] stack: list[str] = [] in_string = False escaped = False for ch in s: out.append(ch) if in_string: if escaped: escaped = False elif ch == "\\": escaped = True elif ch == '"': in_string = False continue if ch == '"': in_string = True continue if ch == "{": stack.append("}") elif ch == "[": stack.append("]") elif ch in ("}", "]"): if stack and stack[-1] == ch: stack.pop() if in_string: out.append('"') while stack: out.append(stack.pop()) return "".join(out) def parse_json_object_from_text(text: str) -> dict[str, Any]: """从模型输出里提取并解析 { ... } JSON 对象。""" s = (text or "").strip() s = re.sub(r"```(?:json)?", "", s, flags=re.IGNORECASE).replace("```", "").strip() start = s.find("{") if start == -1: raise ValueError("未找到 JSON 对象") chunk = s[start:] balanced_chunk = _extract_balanced_json_prefix(chunk) decoder = json.JSONDecoder() last_err: Optional[Exception] = None for candidate in ( balanced_chunk, _repair_loose_json_object(balanced_chunk), _close_truncated_json_object(_repair_loose_json_object(balanced_chunk)), _close_truncated_json_object(_repair_loose_json_object(chunk)), ): try: obj, _ = decoder.raw_decode(candidate) if not isinstance(obj, dict): raise ValueError("JSON 根节点不是对象(dict)") return obj except json.JSONDecodeError as e: last_err = e raise ValueError(f"JSON 解析失败:{last_err}") from last_err def safe_get_str(v: Any) -> Optional[str]: if v is None: return None s = str(v).strip() return s if s else None # ------------------------------------------------------------------ # Agent 多轮对话 + Tool Calling(流式生成器) # ------------------------------------------------------------------ def _iter_chat_stream_events( *, api_base: str, api_key: str, model_name: str, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, temperature: float = 0.3, max_tokens: int = 4096, extra_payload: dict[str, Any] | None = None, connect_timeout_sec: int = 20, read_timeout_sec: int = 300, ): """ 流式调用 OpenAI-compat /chat/completions,逐步 yield 事件: ("delta", str) — 文本增量 ("tool_calls", list) — 完整 tool_calls 列表 [{id, function:{name, arguments}}] ("done", dict) — 最终 usage 等元信息 """ payload: dict[str, Any] = { "model": model_name, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "stream": True, **(extra_payload or {}), } if tools: payload["tools"] = tools payload["tool_choice"] = "auto" resp = requests.post( f"{api_base}/chat/completions", headers={ "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", }, json=payload, stream=True, timeout=(connect_timeout_sec, max(60, read_timeout_sec)), ) if resp.status_code in (408, 429, 500, 502, 503, 504): raise _RetryableLLMError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:300]}") if resp.status_code != 200: raise RuntimeError(f"LLM HTTP {resp.status_code}: {(resp.text or '')[:800]}") resp.encoding = "utf-8" content_parts: list[str] = [] tool_calls_map: dict[int, dict] = {} for line in resp.iter_lines(decode_unicode=True): if not line: continue s = line.strip() if not s.startswith("data:"): continue data_str = s[5:].strip() if not data_str or data_str == "[DONE]": break try: obj = json.loads(data_str) except Exception: continue choices = obj.get("choices") if not isinstance(choices, list) or not choices: continue first = choices[0] if isinstance(choices[0], dict) else {} delta = first.get("delta") if isinstance(first.get("delta"), dict) else {} # 仅输出正文 content;忽略 reasoning_content,避免思考过程展示给用户 _ = delta.get("reasoning_content") content = delta.get("content") if isinstance(content, str) and content: content_parts.append(content) yield ("delta", content) # tool_calls delta (streamed incrementally) tc_deltas = delta.get("tool_calls") if isinstance(tc_deltas, list): for tc in tc_deltas: if not isinstance(tc, dict): continue idx = tc.get("index", 0) if idx not in tool_calls_map: tool_calls_map[idx] = { "id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""}, } entry = tool_calls_map[idx] if tc.get("id"): entry["id"] = tc["id"] fn = tc.get("function") if isinstance(tc.get("function"), dict) else {} if fn.get("name"): entry["function"]["name"] += fn["name"] if fn.get("arguments"): entry["function"]["arguments"] += fn["arguments"] # finish_reason finish = first.get("finish_reason") if finish == "tool_calls" and tool_calls_map: ordered = [tool_calls_map[k] for k in sorted(tool_calls_map.keys())] yield ("tool_calls", ordered) tool_calls_map = {} if tool_calls_map: ordered = [tool_calls_map[k] for k in sorted(tool_calls_map.keys())] yield ("tool_calls", ordered) yield ("done", {"content": "".join(content_parts)}) def _default_disable_thinking_payload(model_name: str) -> dict[str, Any]: """Qwen 等推理模型:关闭 thinking,仅将最终答案写入 content。""" if not model_name or "qwen" not in str(model_name).lower(): return {} return { "enable_thinking": False, # vLLM / 部分 OpenAI 兼容网关使用 chat_template_kwargs "chat_template_kwargs": {"enable_thinking": False}, } def chat_completions_with_tools( *, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, temperature: float = 0.3, max_tokens: int = 4096, timeout_sec: int = 180, extra_payload: dict[str, Any] | None = None, ): """ Agent 用多轮对话 + tool calling。返回生成器,yield 事件元组。 调用方负责工具循环编排。 """ api_base = (settings.LLM_API_BASE or "").rstrip("/") api_key = settings.LLM_API_KEY or "" model_name = settings.LLM_MODEL_NAME or "" if not api_base or not api_key or not model_name: raise RuntimeError("LLM 未配置:请设置 LLM_API_BASE/LLM_API_KEY/LLM_MODEL_NAME") merged_extra: dict[str, Any] = dict(_default_disable_thinking_payload(model_name)) if extra_payload: merged_extra.update(extra_payload) connect_timeout_sec = int(getattr(settings, "LLM_CONNECT_TIMEOUT_SEC", 20) or 20) if connect_timeout_sec <= 0: connect_timeout_sec = 20 final_timeout_sec = int(timeout_sec or 0) if final_timeout_sec <= 0: final_timeout_sec = int(getattr(settings, "LLM_HTTP_TIMEOUT_SEC", 90) or 90) with _llm_slots: yield from _iter_chat_stream_events( api_base=api_base, api_key=api_key, model_name=model_name, messages=messages, tools=tools, temperature=temperature, max_tokens=max_tokens, extra_payload=merged_extra or None, connect_timeout_sec=connect_timeout_sec, read_timeout_sec=final_timeout_sec, )