#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
function/vector_store.py

Vector store module, integrated with the kb_service project.

* ``drop_old`` is always passed to Milvus as ``False``, so existing
  collections are never dropped.
* Over-long inputs (the "413 token too long" problem) are handled by
  splitting oversized documents on paragraph/sentence boundaries before
  they are sent to the embedding backend.
"""

import json
import logging
import re
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple

from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_milvus import Milvus, BM25BuiltInFunction
from pymilvus import MilvusClient, connections

from config import settings

logger = logging.getLogger(__name__)

# ============================================================================
# Configuration
# ============================================================================
COLLECTION_NAME = "eval_report"
EMBEDDING_API_BASE = settings.EMBEDDING_API_BASE
EMBEDDING_API_KEY = settings.EMBEDDING_API_KEY
MILVUS_DB_URL = settings.MILVUS_DB_URL

CONSISTENCY_LEVEL = "Bounded"
AUTO_ID = True
METRIC_TYPE = "COSINE"
INDEX_TYPE = "AUTOINDEX"
SPARSE_METRIC_TYPE = "BM25"
SPARSE_INDEX_TYPE = "SPARSE_INVERTED_INDEX"

# Substrings (lower-case) that identify an out-of-memory failure reported by
# the embedding backend.  "out of memory" also covers the more specific
# "npu out of memory" / "cuda out of memory" messages.
_OOM_MARKERS = ("out of memory", "error code: 424", "'code': 424")

# Separators used when an oversized document has to be split before embedding.
_SAFE_SPLIT_SEPARATORS = ["\n\n", "\n", "。", "!", "?", ";", ":", ","]

# Fields requested from Milvus by the dense search path.  "text" becomes the
# page content, every other field is copied into the document metadata.
_DENSE_SEARCH_FIELDS = (
    "text",
    "heading",
    "heading_level",
    "doc_id",
    "project_uuid",
    "original_title",
    "path",
)

# Queries are truncated to this many characters before being embedded.
_MAX_QUERY_CHARS = 5000


def _int_setting(name: str, default: int, minimum: int) -> int:
    """Read an integer from ``settings``; fall back to *default* when unset or falsy."""
    return max(minimum, int(getattr(settings, name, default) or default))


def _embedding_batch_limits() -> tuple[int, int, int]:
    """Return ``(max_docs_per_batch, max_chars_per_batch, max_chunk_chars)``.

    Values come from the optional ``EMBEDDING_BATCH_MAX_DOCS``,
    ``EMBEDDING_BATCH_MAX_CHARS`` and ``EMBEDDING_MAX_CHUNK_CHARS`` settings
    and are clamped to a minimum of 1, 512 and 512 respectively.
    """
    return (
        _int_setting("EMBEDDING_BATCH_MAX_DOCS", 4, 1),
        _int_setting("EMBEDDING_BATCH_MAX_CHARS", 12000, 512),
        _int_setting("EMBEDDING_MAX_CHUNK_CHARS", 4000, 512),
    )


def _is_embedding_backend_oom(exc: BaseException) -> bool:
    """Return True when *exc*'s message looks like an embedding-backend OOM."""
    message = str(exc).lower()
    return any(marker in message for marker in _OOM_MARKERS)


def _add_documents_batch_with_retry(vs: Milvus, batch: List[Document]) -> List[str]:
    """Write one batch; on an embedding OOM, split it in half and retry.

    Any other exception, or an OOM on a single-document batch, is re-raised.
    """
    if not batch:
        return []
    try:
        return list(vs.add_documents(batch))
    except Exception as e:
        if not _is_embedding_backend_oom(e) or len(batch) <= 1:
            raise
        mid = max(1, len(batch) // 2)
        logger.warning(
            "embedding 批次 OOM,拆分为 %s + %s 重试",
            mid,
            len(batch) - mid,
        )
        ids: List[str] = []
        ids.extend(_add_documents_batch_with_retry(vs, batch[:mid]))
        ids.extend(_add_documents_batch_with_retry(vs, batch[mid:]))
        return ids


def _register_milvus_client_for_orm(client: MilvusClient) -> None:
    """Register *client*'s connection under its alias in the pymilvus ORM.

    pymilvus 2.6+ MilvusClient uses ConnectionManager; ORM Collection still
    resolves pymilvus.orm.connections by client._using. langchain-milvus
    touches Collection during Milvus.__init__, so register before constructing
    Milvus (bootstrap client).
    """
    alias = client._using
    if connections.has_connection(alias):
        return
    cfg = client._config
    connections._alias_handlers[alias] = client._handler
    connections._alias_config[alias] = {
        "address": cfg.address,
        "user": "",
        "db_name": cfg.db_name or "default",
    }


def _split_oversized_documents(
    documents: List[Document], max_chunk_chars: int
) -> List[Document]:
    """Split only the documents longer than *max_chunk_chars*.

    Documents within the limit are passed through as the same objects.
    Oversized ones are cut on paragraph/sentence separators; every resulting
    non-blank chunk gets its own copy of the source metadata.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=max_chunk_chars,
        chunk_overlap=min(200, max(0, max_chunk_chars // 20)),
        separators=_SAFE_SPLIT_SEPARATORS,
    )
    result: List[Document] = []
    for doc in documents:
        if len(doc.page_content) <= max_chunk_chars:
            result.append(doc)
            continue
        for chunk in splitter.split_text(doc.page_content):
            if chunk.strip():
                result.append(
                    Document(page_content=chunk, metadata=dict(doc.metadata or {}))
                )
    return result


def _fill_required_metadata(documents: List[Document]) -> None:
    """Fill in, in place, the metadata fields the existing collection requires.

    Not every historical caller supplies ``doc_id``, ``original_title``,
    ``path`` and ``project_uuid``, so missing ones get fallback values here.
    """
    for idx, doc in enumerate(documents):
        metadata = doc.metadata or {}
        if not metadata.get("doc_id"):
            project_uuid = metadata.get("project_uuid") or "unknown_project"
            heading = metadata.get("heading") or "chunk"
            metadata["doc_id"] = f"{project_uuid}:{heading}:{idx}"
        metadata.setdefault("original_title", metadata.get("heading") or "")
        metadata.setdefault("path", "")
        metadata.setdefault("project_uuid", "unknown_project")
        doc.metadata = metadata


def _iter_batches(
    documents: List[Document], max_docs: int, max_chars: int
) -> Iterator[Tuple[List[Document], int]]:
    """Yield ``(batch, total_chars)`` honouring the per-batch doc and char limits.

    A single document larger than *max_chars* still forms a batch of its own.
    """
    batch: List[Document] = []
    batch_chars = 0
    for doc in documents:
        doc_chars = len(doc.page_content or "")
        if batch and (len(batch) >= max_docs or batch_chars + doc_chars > max_chars):
            yield batch, batch_chars
            batch, batch_chars = [], 0
        batch.append(doc)
        batch_chars += doc_chars
    if batch:
        yield batch, batch_chars


# ============================================================================
# VectorStore (Milvus is always created with drop_old=False)
# ============================================================================
class VectorStore:
    """Hybrid (dense + BM25 sparse) Milvus store for one collection."""

    def __init__(
        self,
        collection_name: str = COLLECTION_NAME,
        drop_old: bool = False,
        chunk_size: int = 500,
        chunk_overlap: int = 50,
    ):
        self.collection_name = collection_name
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self._drop_old = drop_old
        self._milvus = None  # created lazily by _get_milvus()

    def _get_embeddings(self):
        """Build the OpenAI-compatible embedding client for the bge-m3 model."""
        return OpenAIEmbeddings(
            base_url=EMBEDDING_API_BASE,
            api_key=EMBEDDING_API_KEY,
            model="bge-m3",
            check_embedding_ctx_length=False,
        )

    def _get_milvus(self, drop_old: bool = False) -> Milvus:
        """Return the cached Milvus wrapper, creating it on first use.

        ``drop_old=True`` only forces a fresh wrapper to be built; Milvus
        itself is always constructed with ``drop_old=False``, so the
        collection is never dropped.

        Raises:
            ValueError: if ``MILVUS_DB_URL`` is not configured.
            Exception: whatever the Milvus initialisation raised (logged first).
        """
        if self._milvus is not None and not drop_old:
            logger.info("【VectorStore】复用已有 Milvus 实例")
            return self._milvus

        logger.info("【VectorStore】初始化 Milvus 混合向量存储(dense + sparse)")
        if not MILVUS_DB_URL:
            raise ValueError("MILVUS_DB_URL 未配置,请在 .env 中设置")

        embeddings = self._get_embeddings()
        logger.info("【VectorStore】Embedding 模型 bge-m3 初始化完成")

        try:
            # Register the ORM alias first: Milvus.__init__ accesses Collection
            # and fails if the alias is unknown to pymilvus.orm.connections.
            _register_milvus_client_for_orm(MilvusClient(uri=MILVUS_DB_URL))
            store = Milvus(
                embedding_function=embeddings,
                builtin_function=BM25BuiltInFunction(),
                vector_field=["dense", "sparse"],
                connection_args={"uri": MILVUS_DB_URL},
                collection_name=self.collection_name,
                consistency_level=CONSISTENCY_LEVEL,
                auto_id=AUTO_ID,
                drop_old=False,
                index_params=[
                    {"metric_type": METRIC_TYPE, "index_type": INDEX_TYPE},
                    {"metric_type": SPARSE_METRIC_TYPE, "index_type": SPARSE_INDEX_TYPE},
                ],
            )
            _register_milvus_client_for_orm(store.client)
        except Exception as e:
            logger.error("❌ Milvus 初始化失败: %s", e, exc_info=True)
            raise

        # Cache only a fully initialised instance.
        self._milvus = store
        logger.info("✅ Milvus 混合向量存储初始化成功")
        return self._milvus

    def add_documents(self, documents: List[Document]) -> List[str]:
        """Embed and insert *documents*; return the IDs that were written.

        Oversized documents are split on semantic boundaries first, required
        metadata fields are filled in (this mutates the metadata of the
        documents passed in), and the result is written in batches bounded by
        the embedding batch limits.  A failing batch is logged and skipped, so
        the returned list may cover only part of the input.
        """
        if not documents:
            logger.info("【add_documents】无文档可写入")
            return []

        max_docs_per_batch, max_chars_per_batch, max_chunk_chars = _embedding_batch_limits()

        safe_documents = _split_oversized_documents(documents, max_chunk_chars)
        _fill_required_metadata(safe_documents)
        logger.info("【add_documents】预处理后准备写入 %s 条文档", len(safe_documents))

        vs = self._get_milvus(drop_old=self._drop_old)
        self._drop_old = False

        ids: List[str] = []
        failed_batches = 0
        batches = _iter_batches(safe_documents, max_docs_per_batch, max_chars_per_batch)
        for batch_num, (batch, batch_chars) in enumerate(batches, start=1):
            logger.info(
                "【add_documents】写入批次 %s,数量:%s,约 %s 字符",
                batch_num,
                len(batch),
                batch_chars,
            )
            try:
                res = _add_documents_batch_with_retry(vs, batch)
            except Exception as e:
                # Best effort: keep going so one bad batch does not lose the rest.
                failed_batches += 1
                logger.error("❌ 批次写入失败: %s", e, exc_info=True)
                continue
            ids.extend(res)
            logger.info("✅ 批次写入成功,返回 ID 数:%s", len(res))

        if failed_batches:
            logger.warning("【add_documents】%s 个批次写入失败,已跳过", failed_batches)
        logger.info("【add_documents】全部完成,总写入 ID 数:%s", len(ids))
        return ids

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 10,
        filter: Optional[str] = None,
    ) -> List[Tuple[Document, float]]:
        """Search through langchain-milvus and return ``(document, score)`` pairs.

        The query is truncated to 5000 characters.  *filter* is a Milvus
        boolean expression; langchain-milvus names that parameter ``expr``,
        so it is forwarded under that name.
        """
        vs = self._get_milvus(drop_old=False)
        query = query[:_MAX_QUERY_CHARS]
        if filter:
            return vs.similarity_search_with_score(query, k=k, expr=filter)
        return vs.similarity_search_with_score(query, k=k)

    def similarity_search_dense_filtered(
        self,
        query: str,
        k: int,
        filter_expr: str,
    ) -> List[Tuple[Document, float]]:
        """Dense-vector ANN search with a Milvus scalar filter.

        Queries Milvus directly on the ``dense`` field instead of going
        through langchain-milvus, so extraction-side recall can rely on the
        filter (e.g. doc_id isolation) being applied.  Returns an empty list
        for a blank query.
        """
        q = (query or "")[:_MAX_QUERY_CHARS]
        if not q.strip():
            return []

        emb = self._get_embeddings().embed_query(q)
        client = MilvusClient(uri=MILVUS_DB_URL)
        try:
            raw = client.search(
                collection_name=self.collection_name,
                data=[emb],
                anns_field="dense",
                limit=max(1, int(k)),
                filter=filter_expr,
                output_fields=list(_DENSE_SEARCH_FIELDS),
            )
        finally:
            client.close()

        hits = raw[0] if raw else []
        out: List[Tuple[Document, float]] = []
        for hit in hits:
            ent = hit.get("entity") or {}
            doc = Document(
                page_content=str(ent.get("text") or ""),
                metadata={name: ent.get(name) for name in _DENSE_SEARCH_FIELDS[1:]},
            )
            dist = hit.get("distance")
            try:
                score = float(dist) if dist is not None else 0.0
            except (TypeError, ValueError):
                score = 0.0
            out.append((doc, score))
        return out

    def delete_by_filter(self, filter_expr: str) -> int:
        """Delete every entity matching *filter_expr*.

        Returns the number of matching entities counted just before the
        deletion.  Returns 0 when the collection does not exist, when the
        count could not be determined, or when the operation failed (the
        failure is logged, not raised).
        """
        try:
            client = MilvusClient(uri=MILVUS_DB_URL)
            try:
                if not client.has_collection(self.collection_name):
                    return 0

                # The primary key is not necessarily called "id" (langchain-milvus
                # may use a custom PK / auto_id), so look it up from the schema.
                pk_field = None
                describe = client.describe_collection(self.collection_name)
                for f in describe.get("fields", []) or []:
                    # Tolerate the different key spellings returned by the server.
                    if f.get("is_primary") or f.get("isPrimary") or f.get("primary"):
                        pk_field = f.get("name")
                        break

                count = 0
                if pk_field:
                    try:
                        res = client.query(
                            self.collection_name,
                            filter=filter_expr,
                            output_fields=[pk_field],
                        )
                        count = len(res)
                    except Exception:
                        # A failed count must not block the deletion itself.
                        count = 0

                client.delete(self.collection_name, filter=filter_expr)
                return count
            finally:
                # Always release the connection, including on early return.
                client.close()
        except Exception as e:
            logger.error("删除失败: %s", e, exc_info=True)
            return 0
# ============================================================================
# Markdown splitting
# ============================================================================

# ATX heading: one to six '#' characters, whitespace, then the title.
_MD_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$")
# Opening/closing marker of a fenced code block (``` or ~~~, three or more).
_MD_FENCE_RE = re.compile(r"^ {0,3}(`{3,}|~{3,})")


def split_markdown(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> List[str]:
    """Split *text* into chunks of about *chunk_size* characters.

    Returns an empty list for empty input.
    """
    if not text:
        return []
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "。", "?", "!", "\n", ";", ":", ","],
    )
    return splitter.split_text(text)


def split_markdown_by_headings(content: str, chunk_size=300, chunk_overlap=40) -> List[Document]:
    """Split Markdown into one Document per heading section.

    Each document holds the text below a heading (the heading line itself is
    not part of the text) with ``heading`` and ``heading_level`` metadata;
    text before the first heading gets an empty heading and level 0.
    Sections with no body text produce no document.  Lines inside fenced code
    blocks are never treated as headings, so a ``# comment`` in a code sample
    does not start a new section.

    *chunk_size* and *chunk_overlap* are used only by the fallback: when no
    section yields any text, the raw content is split with
    ``split_markdown`` and each chunk also carries ``chunk_index``.
    """
    if not content:
        return []

    docs: List[Document] = []
    heading = ""
    level = 0
    body: List[str] = []
    open_fence = ""  # marker that opened the current fenced block, "" if none

    def flush() -> None:
        """Emit the accumulated section body, if any, and reset the buffer."""
        text = "\n".join(body).strip()
        if text:
            docs.append(
                Document(
                    page_content=text,
                    metadata={"heading": heading, "heading_level": level},
                )
            )
        body.clear()

    for raw_line in content.split("\n"):
        line = raw_line.rstrip()

        fence = _MD_FENCE_RE.match(line)
        if fence:
            marker = fence.group(1)
            if not open_fence:
                open_fence = marker
            elif marker[0] == open_fence[0] and len(marker) >= len(open_fence):
                # A fence closes only on the same character, at least as long.
                open_fence = ""
            body.append(line)
            continue

        heading_match = None if open_fence else _MD_HEADING_RE.match(line)
        if heading_match:
            flush()
            level = len(heading_match.group(1))
            heading = heading_match.group(2).strip()
        else:
            body.append(line)
    flush()

    if not docs:
        for i, chunk in enumerate(split_markdown(content, chunk_size, chunk_overlap)):
            docs.append(
                Document(
                    page_content=chunk,
                    metadata={"chunk_index": i, "heading": "", "heading_level": 0},
                )
            )
    return docs


def process_document_to_vector_store(
    doc_id: str,
    title: str,
    content: str,
    path: str,
    project_uuid: str,
    collection_name=COLLECTION_NAME,
) -> bool:
    """Split *content* by headings and write the sections to the vector store.

    Every section is tagged with ``doc_id``, ``original_title``, ``path`` and
    ``project_uuid``.  Returns True on success (including when there was
    nothing to write) and False when an error occurred or when sections
    existed but none of them was stored.
    """
    try:
        vs = VectorStore(collection_name=collection_name, drop_old=False)
        docs = split_markdown_by_headings(content)
        for d in docs:
            d.metadata["doc_id"] = doc_id
            d.metadata["original_title"] = title
            d.metadata["path"] = path
            d.metadata["project_uuid"] = project_uuid
        ids = vs.add_documents(docs)
    except Exception as e:
        logger.error("处理文档失败: %s", e, exc_info=True)
        return False

    # add_documents logs and skips failing batches instead of raising, so an
    # empty result for a non-empty input means nothing was actually stored.
    if docs and not ids:
        logger.error("处理文档失败: %s", f"no chunks were written for doc_id={doc_id}")
        return False
    return True


# ============================================================================
# Data preprocessing
# ============================================================================
INPUT_FILE = "data/articles.jsonl"
OUTPUT_CHUNK_FILE = "data/processed/eval_chunks.jsonl"

# Patterns used by clean_text(), compiled once.
_CONTROL_CHARS_RE = re.compile(r"[\x00-\x09\x0B-\x1F\x7F]")  # keeps "\n" (0x0A)
_INVISIBLE_CHARS_RE = re.compile(r"[\u200b-\u200f\u2028\u2029]")
_DECORATIVE_CHARS_RE = re.compile(r"[:’“”•…–—]")
_HTML_TAG_RE = re.compile(r"<[^>]+>")
_NEWLINE_RUN_RE = re.compile(r"\n+")
_SPACE_RUN_RE = re.compile(r" +")
_LEADING_PUNCT_RE = re.compile(r"^[。,?!;:]")

# Punctuation kept by the final whitelist.  re.escape() is applied because
# the set contains "[", "]", "^" and "-", which are special inside a
# character class; left unescaped, "]" ends the class early.
_ALLOWED_PUNCTUATION = ",。!?;:、()《》【】\"'·!@#$%^&*()_+=[]{}|;:,./<>?-"
_DISALLOWED_CHARS_RE = re.compile(
    r"[^\u4e00-\u9fff_a-zA-Z0-9\s" + re.escape(_ALLOWED_PUNCTUATION) + "]"
)


def load_jsonl(filename: str, encoding="utf-8"):
    """Yield one parsed JSON object per non-blank line of *filename*."""
    with open(filename, encoding=encoding) as f:
        for line in f:
            if line.strip():
                yield json.loads(line)


def write_jsonl(data, filename, append=False, ensure_ascii=False):
    """Write *data* to *filename* as UTF-8 JSON Lines, one item per line.

    The parent directory is created when it does not exist.  With
    ``append=True`` the items are added to the end of an existing file.
    """
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    mode = "a" if append else "w"
    with open(filename, mode, encoding="utf-8") as f:
        f.writelines(
            json.dumps(item, ensure_ascii=ensure_ascii) + "\n" for item in data
        )


def clean_text(text: str) -> str:
    """Normalise raw article text; return "" for non-string input.

    Steps, in order: drop control characters (newlines are kept), drop
    zero-width/line-separator characters, drop a set of decorative
    punctuation, turn HTML tags into newlines, collapse runs of newlines and
    of spaces, drop one leading punctuation mark, remove every character
    outside the whitelist (CJK ideographs, ASCII letters/digits/underscore,
    whitespace and the allowed punctuation), and strip the result.
    """
    if not isinstance(text, str):
        return ""
    text = _CONTROL_CHARS_RE.sub("", text)
    text = _INVISIBLE_CHARS_RE.sub("", text)
    text = _DECORATIVE_CHARS_RE.sub("", text)
    text = _HTML_TAG_RE.sub("\n", text)
    text = _NEWLINE_RUN_RE.sub("\n", text)
    text = _SPACE_RUN_RE.sub(" ", text)
    text = _LEADING_PUNCT_RE.sub("", text)
    text = _DISALLOWED_CHARS_RE.sub("", text)
    return text.strip()


def concat_metadata_to_content(title: str, content: str, metadata: dict):
    """Prefix *content* with a one-line header built from the title and metadata.

    The header joins the title, publish time, author and source with " | ";
    fields that are missing (None) or empty are left out.  The header and the
    stripped content are separated by a "---" line.
    """
    metadata = metadata or {}
    fields = (
        ("标题", title),
        ("发布时间", metadata.get("publish_time")),
        ("作者", metadata.get("author")),
        ("来源", metadata.get("source")),
    )
    header = " | ".join(
        f"{label}:{value}" for label, value in fields if value not in (None, "")
    )
    return header + "\n---\n" + content.strip()


def process_all_documents(input_file, output_file, chunk_size=500, overlap=50):
    """Clean and chunk every article in *input_file*, writing chunks to *output_file*.

    Each input record needs ``id``, ``title`` and ``content``; ``metadata`` is
    optional.  Chunks shorter than 10 characters after cleaning are dropped.
    Returns ``{"num_docs": ..., "num_chunks": ...}``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=overlap,
        separators=["\n\n", "。", "?", "!", "\n", ";", ":", ","],
    )
    all_chunks = []
    num_docs = 0
    for doc in load_jsonl(input_file):
        num_docs += 1
        metadata = doc.get("metadata") or {}
        content = clean_text(doc["content"])
        # enumerate() supplies the chunk index; split_text() returns plain strings.
        for i, chunk in enumerate(splitter.split_text(content)):
            clean_c = clean_text(chunk)
            if len(clean_c) < 10:
                continue
            all_chunks.append({
                "id": f"{doc['id']}_chunk_{i}",
                "doc_id": doc["id"],
                "title": doc["title"],
                "content": concat_metadata_to_content(doc["title"], clean_c, metadata),
                "chunk_index": i,
                "url": metadata.get("url", ""),
            })
    write_jsonl(all_chunks, output_file)
    return {"num_docs": num_docs, "num_chunks": len(all_chunks)}


def load_chunk_jsonl(path):
    """Load every record of a UTF-8 JSON Lines file into a list."""
    return list(load_jsonl(path))


def build_index(data, vs: "VectorStore"):
    """Write the chunk records in *data* to the vector store *vs*.

    ``content`` becomes the page content and the remaining keys become the
    metadata.  Records whose stripped content is shorter than 10 characters
    are skipped.  The input records are not modified.
    """
    docs = []
    for row in data:
        metadata = dict(row)  # copy so the caller's record keeps "content"
        text = str(metadata.pop("content", "") or "").strip()
        if len(text) < 10:
            continue
        docs.append(Document(page_content=text, metadata=metadata))
    if docs:
        vs.add_documents(docs)


def get_vector_store(drop_old=False):
    """Return the Milvus wrapper for the default collection."""
    vs = VectorStore(collection_name=COLLECTION_NAME, drop_old=drop_old)
    return vs._get_milvus(drop_old=drop_old)


def search_eval(query, top_k=10):
    """Run a similarity search on the default collection and print the elapsed time."""
    from time import time

    vs = VectorStore(drop_old=False)
    st = time()
    results = vs.similarity_search_with_score(query, k=top_k)
    print(f"检索耗时: {time()-st:.2f}s")
    return results


# ============================================================================
# Script entry point
# ============================================================================
if __name__ == "__main__":
    # Without a configured handler the INFO messages below are dropped.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    logger.info("=" * 60)
    logger.info("【Milvus 混合向量索引构建启动】dense + sparse(BM25)")
    logger.info("=" * 60)

    process_all_documents(INPUT_FILE, OUTPUT_CHUNK_FILE)
    logger.info("✅ 文本分块处理完成")

    chunk_data = load_chunk_jsonl(OUTPUT_CHUNK_FILE)
    logger.info(f"✅ 加载分块数据:{len(chunk_data)} 条")

    vs = VectorStore(drop_old=False)
    build_index(chunk_data, vs)
    logger.info("✅ 索引构建完成(增量写入)")

    res = search_eval("测试检索内容")
    logger.info(f"✅ 检索完成,命中数量:{len(res)}")
    for doc, score in res:
        logger.info(f"score={score:.4f} | content={doc.page_content[:80]}...")

    logger.info("=" * 60)
    logger.info("【全部执行完成】")