report_generation/function/vector_store.py
xxy aa98ea2623 @
Initial commit

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@
2026-06-05 18:45:29 +08:00

551 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
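"""
function/vector_store.py
Vector store module - integrates with the kb_service project.
Changed: drop_old is False everywhere, so existing collections are never dropped.
✅ Fixed the 413 over-long token problem (semantics-friendly version).
"""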
import re
import json
import logging
from typing import Dict, List, Optional, Tuple
from pathlib import Path
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_milvus import Milvus, BM25BuiltInFunction
from pymilvus import MilvusClient, connections
from config import settings
logger = logging.getLogger(__name__)
# ============================================================================
# Configuration
# ============================================================================
# Default Milvus collection used when a caller does not name one.
COLLECTION_NAME = "eval_report"
# Endpoint and key of the embedding service; handed to OpenAIEmbeddings in
# VectorStore._get_embeddings.
EMBEDDING_API_BASE = settings.EMBEDDING_API_BASE
EMBEDDING_API_KEY = settings.EMBEDDING_API_KEY
# Milvus connection URI; VectorStore._get_milvus raises ValueError when it is empty.
MILVUS_DB_URL = settings.MILVUS_DB_URL
# Passed as consistency_level / auto_id when the Milvus store is constructed.
CONSISTENCY_LEVEL = "Bounded"
AUTO_ID = True
# First index_params entry (presumably for the "dense" vector field - confirm).
METRIC_TYPE = "COSINE"
INDEX_TYPE = "AUTOINDEX"
# Second index_params entry (presumably for the BM25 "sparse" field - confirm).
SPARSE_METRIC_TYPE = "BM25"
SPARSE_INDEX_TYPE = "SPARSE_INVERTED_INDEX"
def _embedding_batch_limits() -> tuple[int, int, int]:
    """Read the embedding batch limits from ``settings``.

    Returns:
        ``(max_docs, max_chars, max_chunk)``: documents per batch (at least 1),
        characters per batch (at least 512) and characters per single chunk
        (at least 512).
    """

    def _limit(name: str, fallback: int, floor: int) -> int:
        # A missing attribute and a falsy value (None, 0, "") both fall back.
        configured = getattr(settings, name, fallback) or fallback
        return max(floor, int(configured))

    return (
        _limit("EMBEDDING_BATCH_MAX_DOCS", 4, 1),
        _limit("EMBEDDING_BATCH_MAX_CHARS", 12000, 512),
        _limit("EMBEDDING_MAX_CHUNK_CHARS", 4000, 512),
    )
def _is_embedding_backend_oom(exc: BaseException) -> bool:
msg = str(exc).lower()
return (
"out of memory" in msg
or "npu out of memory" in msg
or "cuda out of memory" in msg
or "error code: 424" in msg
or "'code': 424" in msg
)
def _add_documents_batch_with_retry(vs: Milvus, batch: List[Document]) -> List[str]:
    """Write one batch of documents, halving it and retrying on embedding OOM.

    Args:
        vs: Vector store whose ``add_documents`` performs the write.
        batch: Documents to write; an empty batch is a no-op.

    Returns:
        The IDs returned by ``vs.add_documents``, in batch order.

    Raises:
        Exception: Whatever ``vs.add_documents`` raised, when the failure is not
            an OOM (see ``_is_embedding_backend_oom``) or the batch holds a
            single document and cannot be split further.
    """
    if not batch:
        return []
    try:
        return list(vs.add_documents(batch))
    except Exception as err:
        splittable = len(batch) > 1
        if not (_is_embedding_backend_oom(err) and splittable):
            raise
        split_at = max(1, len(batch) // 2)
        head, tail = batch[:split_at], batch[split_at:]
        logger.warning(
            "embedding 批次 OOM拆分为 %s + %s 重试",
            len(head),
            len(tail),
        )
        # Each half is retried independently and may itself be split again.
        return [
            *_add_documents_batch_with_retry(vs, head),
            *_add_documents_batch_with_retry(vs, tail),
        ]
def _register_milvus_client_for_orm(client: MilvusClient) -> None:
    """Make a ``MilvusClient`` connection visible to the pymilvus ORM layer.

    pymilvus 2.6+ ``MilvusClient`` uses ConnectionManager, while the ORM
    ``Collection`` still resolves ``pymilvus.orm.connections`` by
    ``client._using``. langchain-milvus touches ``Collection`` during
    ``Milvus.__init__``, so a bootstrap client must be registered before
    ``Milvus`` is constructed.

    Args:
        client: Client whose handler and address are copied into the ORM
            connection registry under the alias ``client._using``.
    """
    alias = client._using
    if not connections.has_connection(alias):
        cfg = client._config
        # NOTE: relies on private pymilvus attributes; re-check on upgrades.
        connections._alias_handlers[alias] = client._handler
        connections._alias_config[alias] = dict(
            address=cfg.address,
            user="",
            db_name=cfg.db_name or "default",
        )
# ============================================================================
# VectorStore class (every code path now uses drop_old=False)
# ============================================================================
class VectorStore:
    """Access layer for one hybrid (dense + BM25 sparse) Milvus collection.

    The ``langchain_milvus.Milvus`` instance is built lazily by
    :meth:`_get_milvus` and cached on the instance. It is always constructed
    with ``drop_old=False``, so this class never drops an existing collection.
    """

    def __init__(
        self,
        collection_name: str = COLLECTION_NAME,
        drop_old: bool = False,
        chunk_size: int = 500,
        chunk_overlap: int = 50
    ):
        """Store the settings; no connection is opened here.

        Args:
            collection_name: Milvus collection to read from and write to.
            drop_old: Forwarded to :meth:`_get_milvus` by the first
                :meth:`add_documents` call, where it only forces a fresh
                ``Milvus`` instance (the collection itself is kept).
            chunk_size: Stored on the instance; no method of this class reads it.
            chunk_overlap: Stored on the instance; no method of this class reads it.
        """
        self.collection_name = collection_name
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self._drop_old = drop_old
        self._milvus = None

    def _get_embeddings(self):
        """Build the OpenAI-compatible embedding client for the ``bge-m3`` model."""
        return OpenAIEmbeddings(
            base_url=EMBEDDING_API_BASE,
            api_key=EMBEDDING_API_KEY,
            model="bge-m3",
            # Presumably required because the backend is not OpenAI itself and
            # the client-side token-length handling does not apply - confirm.
            check_embedding_ctx_length=False,
        )

    def _get_milvus(self, drop_old: bool = False) -> Milvus:
        """Return the cached ``Milvus`` store, creating it when needed.

        Args:
            drop_old: When True, a cached instance is ignored and a new one is
                built. The collection is NOT dropped: ``Milvus`` is always
                constructed with ``drop_old=False``.

        Returns:
            The hybrid (dense + sparse) ``Milvus`` vector store.

        Raises:
            ValueError: ``MILVUS_DB_URL`` is not configured.
            Exception: Any error raised while constructing ``Milvus`` is logged
                and re-raised.
        """
        logger.info("【VectorStore】初始化 Milvus 混合向量存储dense + sparse")
        if self._milvus is not None and not drop_old:
            logger.info("【VectorStore】复用已有 Milvus 实例")
            return self._milvus
        if not MILVUS_DB_URL:
            raise ValueError("MILVUS_DB_URL 未配置,请在 .env 中设置")
        embeddings = self._get_embeddings()
        logger.info("【VectorStore】Embedding 模型 bge-m3 初始化完成")
        try:
            # Share the ConnectionManager with langchain's internal MilvusClient:
            # the ORM alias must be registered first, otherwise touching
            # Collection inside Milvus.__init__ fails.
            _register_milvus_client_for_orm(MilvusClient(uri=MILVUS_DB_URL))
            self._milvus = Milvus(
                embedding_function=embeddings,
                builtin_function=BM25BuiltInFunction(),
                vector_field=["dense", "sparse"],
                connection_args={"uri": MILVUS_DB_URL},
                collection_name=self.collection_name,
                consistency_level=CONSISTENCY_LEVEL,
                auto_id=AUTO_ID,
                drop_old=False,
                index_params=[
                    {"metric_type": METRIC_TYPE, "index_type": INDEX_TYPE},
                    {"metric_type": SPARSE_METRIC_TYPE, "index_type": SPARSE_INDEX_TYPE},
                ],
            )
            _register_milvus_client_for_orm(self._milvus.client)
            logger.info("✅ Milvus 混合向量存储初始化成功")
        except Exception as e:
            logger.error(f"❌ Milvus 初始化失败: {str(e)}", exc_info=True)
            raise
        return self._milvus

    @staticmethod
    def _split_oversized_documents(
        documents: List[Document], max_chunk_chars: int
    ) -> List[Document]:
        """Split only the documents longer than ``max_chunk_chars``; keep the rest as-is."""
        safe_splitter = RecursiveCharacterTextSplitter(
            chunk_size=max_chunk_chars,
            chunk_overlap=min(200, max(0, max_chunk_chars // 20)),
            # NOTE(review): the trailing separators are empty strings; they look
            # like CJK punctuation lost in an encoding step, and entries after
            # the first "" are presumably never reached - confirm.
            separators=["\n\n", "\n", "", "", "", "", "", ""]
        )
        safe_documents: List[Document] = []
        for doc in documents:
            text = doc.page_content or ""
            if len(text) <= max_chunk_chars:
                # Short enough: pass the caller's object through untouched.
                safe_documents.append(doc)
                continue
            for chunk in safe_splitter.split_text(text):
                if chunk.strip():
                    safe_documents.append(Document(
                        page_content=chunk,
                        metadata=dict(doc.metadata or {})
                    ))
        return safe_documents

    @staticmethod
    def _ensure_required_metadata(documents: List[Document]) -> None:
        """Fill in the metadata fields the existing collection requires.

        Older callers do not always pass these fields, so defaults are supplied
        here. Metadata dicts are updated in place.
        """
        for idx, doc in enumerate(documents):
            metadata = doc.metadata or {}
            if not metadata.get("doc_id"):
                project_uuid = metadata.get("project_uuid") or "unknown_project"
                heading = metadata.get("heading") or "chunk"
                metadata["doc_id"] = f"{project_uuid}:{heading}:{idx}"
            metadata.setdefault("original_title", metadata.get("heading") or "")
            metadata.setdefault("path", "")
            metadata.setdefault("project_uuid", "unknown_project")
            doc.metadata = metadata

    # ========================================================================
    # add_documents: semantics-friendly splitting that avoids the 413 error
    # ========================================================================
    def add_documents(self, documents: List[Document]) -> List[str]:
        """Split over-long documents, fill required metadata and write in batches.

        Batches are bounded by the limits from ``_embedding_batch_limits``. A
        batch that fails to write is logged and skipped, so the result may cover
        only part of the input.

        Args:
            documents: Documents to store. Documents that need no splitting are
                written as the same objects, with their metadata updated in place.

        Returns:
            IDs of the documents that were written successfully.
        """
        if not documents:
            logger.info("【add_documents】无文档可写入")
            return []
        max_docs_per_batch, max_chars_per_batch, max_chunk_chars = _embedding_batch_limits()
        safe_documents = self._split_oversized_documents(documents, max_chunk_chars)
        self._ensure_required_metadata(safe_documents)
        logger.info(f"【add_documents】预处理后准备写入 {len(safe_documents)} 条文档")
        vs = self._get_milvus(drop_old=self._drop_old)
        # The constructor's drop_old request is honoured once only.
        self._drop_old = False
        ids: List[str] = []
        current_batch: List[Document] = []
        current_batch_chars = 0
        batch_num = 1

        def _flush_batch() -> None:
            nonlocal current_batch, current_batch_chars, batch_num
            if not current_batch:
                return
            logger.info(
                "【add_documents】写入批次 %s,数量:%s,约 %s 字符",
                batch_num,
                len(current_batch),
                current_batch_chars,
            )
            try:
                res = _add_documents_batch_with_retry(vs, current_batch)
                ids.extend(res)
                logger.info("✅ 批次写入成功,返回 ID 数:%s", len(res))
            except Exception as e:
                # Best effort: one failing batch must not abort the others.
                logger.error("❌ 批次写入失败: %s", e, exc_info=True)
            batch_num += 1
            current_batch = []
            current_batch_chars = 0

        for doc in safe_documents:
            doc_chars = len(doc.page_content or "")
            would_exceed_docs = bool(current_batch) and len(current_batch) >= max_docs_per_batch
            would_exceed_chars = bool(current_batch) and (
                current_batch_chars + doc_chars > max_chars_per_batch
            )
            if would_exceed_docs or would_exceed_chars:
                _flush_batch()
            current_batch.append(doc)
            current_batch_chars += doc_chars
        _flush_batch()
        logger.info(f"【add_documents】全部完成总写入 ID 数:{len(ids)}")
        return ids

    def similarity_search_with_score(
        self, query: str, k: int = 10, filter: Optional[str] = None
    ) -> List[Tuple[Document, float]]:
        """Search the store through langchain-milvus.

        Args:
            query: Query text; truncated to 5000 characters.
            k: Maximum number of hits.
            filter: Optional Milvus boolean expression restricting the search.

        Returns:
            ``(document, score)`` pairs; an empty list for a blank query.
        """
        q = (query or "")[:5000]
        if not q.strip():
            # Same contract as similarity_search_dense_filtered: nothing to embed.
            return []
        vs = self._get_milvus(drop_old=False)
        if filter:
            # langchain-milvus names this parameter ``expr``; ``filter`` is not
            # part of its documented signature.
            return vs.similarity_search_with_score(q, k=k, expr=filter)
        return vs.similarity_search_with_score(q, k=k)

    def similarity_search_dense_filtered(
        self,
        query: str,
        k: int,
        filter_expr: str,
    ) -> List[Tuple[Document, float]]:
        """Dense-vector ANN search with a Milvus scalar filter.

        On a hybrid (dense + sparse) collection the langchain_milvus filter may
        not take effect, so extraction-side recall uses this path to guarantee
        ``doc_id`` isolation.

        Args:
            query: Query text; truncated to 5000 characters.
            k: Maximum number of hits (at least 1 is requested).
            filter_expr: Milvus boolean expression applied to the search.

        Returns:
            ``(document, score)`` pairs, where the score is the ``distance``
            reported by Milvus (0.0 when missing or not numeric). An empty list
            for a blank query.
        """
        q = (query or "")[:5000]
        if not q.strip():
            return []
        emb = self._get_embeddings().embed_query(q)
        client = MilvusClient(uri=MILVUS_DB_URL)
        try:
            raw = client.search(
                collection_name=self.collection_name,
                data=[emb],
                anns_field="dense",
                limit=max(1, int(k)),
                filter=filter_expr,
                output_fields=[
                    "text",
                    "heading",
                    "heading_level",
                    "doc_id",
                    "project_uuid",
                    "original_title",
                    "path",
                ],
            )
        finally:
            client.close()
        hits = raw[0] if raw else []
        out: List[Tuple[Document, float]] = []
        for hit in hits:
            ent = hit.get("entity") or {}
            doc = Document(
                page_content=str(ent.get("text") or ""),
                metadata={
                    "heading": ent.get("heading"),
                    "heading_level": ent.get("heading_level"),
                    "doc_id": ent.get("doc_id"),
                    "project_uuid": ent.get("project_uuid"),
                    "original_title": ent.get("original_title"),
                    "path": ent.get("path"),
                },
            )
            dist = hit.get("distance")
            try:
                score = float(dist) if dist is not None else 0.0
            except (TypeError, ValueError):
                score = 0.0
            out.append((doc, score))
        return out

    def delete_by_filter(self, filter_expr: str) -> int:
        """Delete every entity matching ``filter_expr``.

        Never raises: failures are logged and reported as 0.

        Args:
            filter_expr: Milvus boolean expression selecting the rows to delete.

        Returns:
            Number of rows that matched before deletion. 0 when the collection
            does not exist, when counting was not possible, or on failure.
        """
        client = None
        try:
            client = MilvusClient(uri=MILVUS_DB_URL)
            if not client.has_collection(self.collection_name):
                return 0
            # The primary key is not always called "id" (langchain-milvus may
            # use a custom PK / auto_id), so look it up in the collection
            # description before counting.
            pk_field = None
            describe = client.describe_collection(self.collection_name)
            for f in describe.get("fields", []) or []:
                # Tolerate the different spellings seen in responses.
                if f.get("is_primary") or f.get("isPrimary") or f.get("primary"):
                    pk_field = f.get("name")
                    break
            count = 0
            if pk_field:
                try:
                    res = client.query(
                        self.collection_name,
                        filter=filter_expr,
                        output_fields=[pk_field],
                    )
                    count = len(res)
                except Exception as count_err:
                    # Counting is informational only; it must not block deletion.
                    logger.debug("统计待删除条数失败: %s", count_err)
                    count = 0
            client.delete(self.collection_name, filter=filter_expr)
            return count
        except Exception as e:
            logger.error(f"删除失败: {e}")
            return 0
        finally:
            # Close on every path, including the early return and failures.
            if client is not None:
                try:
                    client.close()
                except Exception as close_err:
                    logger.debug("关闭 Milvus 客户端失败: %s", close_err)
# ============================================================================
# Markdown 拆分
# ============================================================================
def split_markdown(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> List[str]:
    """Split ``text`` into chunks with ``RecursiveCharacterTextSplitter``.

    Args:
        text: Text to split; empty or None yields an empty list.
        chunk_size: Chunk size passed to the splitter.
        chunk_overlap: Overlap passed to the splitter.

    Returns:
        The chunks produced by the splitter.
    """
    if not text:
        return []
    # NOTE(review): several separators are empty strings; they look like CJK
    # punctuation lost in an encoding step, and entries after the first ""
    # (including "\n") are presumably never reached - confirm.
    separators = ["\n\n", "", "", "", "\n", "", "", ""]
    return RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=separators,
    ).split_text(text)
def split_markdown_by_headings(content: str, chunk_size=300, chunk_overlap=40) -> List[Document]:
    """Split Markdown into one ``Document`` per heading section.

    Each run of lines under an ATX heading (``#`` to ``######``) becomes a
    document with the heading text and level in its metadata. Text before the
    first heading gets heading ``""`` and level 0, and sections with no body
    text are dropped. When no section is produced at all, the whole content is
    split with ``split_markdown`` instead.

    Args:
        content: Markdown text; empty or None yields an empty list.
        chunk_size: Chunk size for the ``split_markdown`` fallback only.
        chunk_overlap: Chunk overlap for the ``split_markdown`` fallback only.

    Returns:
        The section documents, or the fallback chunks (which also carry a
        ``chunk_index`` in their metadata).
    """
    if not content:
        return []

    heading_pattern = re.compile(r"^(#{1,6})\s+(.+)$")
    sections: List[Document] = []
    heading_text = ""
    heading_depth = 0
    pending: List[str] = []

    def emit_section() -> None:
        # Uses the heading that was current while ``pending`` was collected.
        body = "\n".join(pending).strip()
        if body:
            sections.append(Document(
                page_content=body,
                metadata={"heading": heading_text, "heading_level": heading_depth}
            ))
        pending.clear()

    for raw_line in content.split("\n"):
        stripped = raw_line.rstrip()
        found = heading_pattern.match(stripped)
        if found is None:
            pending.append(stripped)
            continue
        # A new heading closes the section collected so far.
        emit_section()
        heading_depth = len(found.group(1))
        heading_text = found.group(2).strip()
    emit_section()

    if sections:
        return sections
    return [
        Document(
            page_content=piece,
            metadata={"chunk_index": position, "heading": "", "heading_level": 0},
        )
        for position, piece in enumerate(split_markdown(content, chunk_size, chunk_overlap))
    ]
def process_document_to_vector_store(
    doc_id: str, title: str, content: str, path: str, project_uuid: str, collection_name=COLLECTION_NAME
) -> bool:
    """Split one Markdown document by headings and write it to the vector store.

    Args:
        doc_id: Identifier stored on every chunk as ``doc_id``.
        title: Stored on every chunk as ``original_title``.
        content: Markdown text of the document.
        path: Stored on every chunk as ``path``.
        project_uuid: Stored on every chunk as ``project_uuid``.
        collection_name: Target Milvus collection.

    Returns:
        True when processing succeeded (including the case of nothing to
        write); False when an error occurred or no chunk could be written.
    """
    try:
        vs = VectorStore(collection_name=collection_name, drop_old=False)
        docs = split_markdown_by_headings(content)
        for d in docs:
            d.metadata.update(
                {
                    "doc_id": doc_id,
                    "original_title": title,
                    "path": path,
                    "project_uuid": project_uuid,
                }
            )
        ids = vs.add_documents(docs)
    except Exception as e:
        logger.error(f"处理文档失败: {e}", exc_info=True)
        return False
    # add_documents logs and skips failing batches instead of raising, so an
    # empty result for a non-empty input means nothing was stored.
    if docs and not ids:
        logger.error("处理文档失败: 所有批次均未写入成功 doc_id=%s", doc_id)
        return False
    return True
# ============================================================================
# Data preprocessing
# ============================================================================
# Default paths used by the __main__ entry point: the JSONL file read by
# process_all_documents and the chunk file it writes.
INPUT_FILE = "data/articles.jsonl"
OUTPUT_CHUNK_FILE = "data/processed/eval_chunks.jsonl"
def load_jsonl(filename: str, encoding="utf-8"):
    """Lazily yield one parsed JSON value per non-blank line of ``filename``.

    Args:
        filename: Path of the JSON Lines file.
        encoding: Text encoding used to read the file.

    Yields:
        The decoded value of each line, in file order.
    """
    with open(filename, encoding=encoding) as handle:
        yield from (json.loads(raw) for raw in handle if raw.strip())
def write_jsonl(data, filename, append=False, ensure_ascii=False):
    """Write ``data`` to ``filename`` as JSON Lines (UTF-8).

    Missing parent directories are created first, so a fresh checkout without
    the output directory does not fail.

    Args:
        data: Iterable of JSON-serialisable items; one line is written per item.
        filename: Target file path.
        append: Append to an existing file instead of overwriting it.
        ensure_ascii: Passed to ``json.dumps``.
    """
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    mode = "a" if append else "w"
    with open(filename, mode, encoding="utf-8") as f:
        f.writelines(
            json.dumps(item, ensure_ascii=ensure_ascii) + "\n" for item in data
        )
def clean_text(text: str) -> str:
if not isinstance(text, str): return ""
text = re.sub(r"[\x00-\x09\x0B-\x1F\x7F]", "", text)
text = re.sub(r"[\u200b-\u200f\u2028\u2029]", "", text)
text = re.sub(r"[:’“”•…–—]", "", text)
text = re.sub(r"<[^>]+>", "\n", text)
text = re.sub(r"\n+", "\n", text)
text = re.sub(r" +", " ", text)
text = re.sub(r"^[。,?!;:]", "", text)
text = re.sub(r'[^\u4e00-\u9fff_a-zA-Z0-9\s《》【】""''·!@#$%^&*()_+=[]{}|;:\'",./<>?-]', "", text)
return text.strip()
def concat_metadata_to_content(title: str, content: str, metadata: dict) -> str:
    """Prefix ``content`` with a one-line header built from the title and metadata.

    The header lists title, publish time, author and source as
    ``label:value`` entries joined by ``" | "``. Entries whose value is missing
    or empty are left out. The header is followed by a ``---`` line and the
    stripped content.

    Args:
        title: Document title.
        content: Chunk text.
        metadata: Mapping that may hold ``publish_time``, ``author`` and
            ``source``; None is treated as empty.

    Returns:
        The header, the ``---`` separator line and the stripped content.
    """
    metadata = metadata or {}
    labelled_values = (
        ("标题", title),
        ("发布时间", metadata.get("publish_time")),
        ("作者", metadata.get("author")),
        ("来源", metadata.get("source")),
    )
    # Filter on the value itself: splitting the rendered string on an empty
    # separator raises ValueError, and a missing value would render as "None".
    parts = [
        f"{label}:{value}"
        for label, value in labelled_values
        if value is not None and value != ""
    ]
    return " | ".join(parts) + "\n---\n" + (content or "").strip()
def process_all_documents(input_file, output_file, chunk_size=500, overlap=50):
    """Clean and chunk every document of a JSONL file and write the chunks out.

    Each input record must provide ``id``, ``title`` and ``content``;
    ``metadata`` is optional. Chunks shorter than 10 characters after cleaning
    are skipped.

    Args:
        input_file: JSONL file with the source documents.
        output_file: JSONL file that receives the chunk records (overwritten).
        chunk_size: Chunk size passed to the text splitter.
        overlap: Chunk overlap passed to the text splitter.

    Returns:
        ``{"num_docs": ..., "num_chunks": ...}`` with the number of documents
        read and chunks written.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=overlap,
        # NOTE(review): several separators are empty strings; they look like
        # CJK punctuation lost in an encoding step - confirm.
        separators=["\n\n", "", "", "", "\n", "", "", ""],
    )
    all_chunks = []
    num_docs = 0
    for doc in load_jsonl(input_file):
        num_docs += 1
        doc_meta = doc.get("metadata") or {}
        content = clean_text(doc["content"])
        # enumerate supplies the chunk index; iterating the chunk strings
        # directly and unpacking each into two names raises ValueError.
        for i, chunk in enumerate(splitter.split_text(content)):
            clean_c = clean_text(chunk)
            if len(clean_c) < 10:
                continue
            all_chunks.append({
                "id": f"{doc['id']}_chunk_{i}",
                "doc_id": doc["id"],
                "title": doc["title"],
                "content": concat_metadata_to_content(doc["title"], clean_c, doc_meta),
                "chunk_index": i,
                "url": doc_meta.get("url", ""),
            })
    write_jsonl(all_chunks, output_file)
    return {"num_docs": num_docs, "num_chunks": len(all_chunks)}
def load_chunk_jsonl(path):
    """Read a UTF-8 JSON Lines file into a list, skipping blank lines.

    Args:
        path: Path of the chunk file.

    Returns:
        The decoded value of every non-blank line, in file order.
    """
    with open(path, encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]
def build_index(data, vs: VectorStore):
    """Turn chunk records into documents and write them to the vector store.

    The ``content`` field becomes the page content and every other field
    becomes metadata. Records whose stripped content is shorter than 10
    characters are skipped. The input records are not modified.

    Args:
        data: Iterable of chunk dicts.
        vs: Vector store that receives the documents.
    """
    docs: List[Document] = []
    for row in data:
        # Work on a copy so the caller's records keep their "content" field.
        metadata = dict(row)
        text = (metadata.pop("content", "") or "").strip()
        if len(text) < 10:
            continue
        docs.append(Document(page_content=text, metadata=metadata))
    if docs:
        vs.add_documents(docs)
def get_vector_store(drop_old=False):
    """Return the underlying ``Milvus`` store of the default collection.

    Args:
        drop_old: Passed to both ``VectorStore`` and its ``_get_milvus``.

    Returns:
        The ``Milvus`` instance created by ``VectorStore._get_milvus``.
    """
    store = VectorStore(
        collection_name=COLLECTION_NAME,
        drop_old=drop_old,
    )
    return store._get_milvus(drop_old=drop_old)
def search_eval(query, top_k=10):
    """Run a similarity search on the default collection and print its duration.

    Args:
        query: Query text.
        top_k: Maximum number of hits.

    Returns:
        The result of ``VectorStore.similarity_search_with_score``.
    """
    from time import time

    store = VectorStore(drop_old=False)
    started = time()
    hits = store.similarity_search_with_score(query, k=top_k)
    elapsed = time() - started
    print(f"检索耗时: {elapsed:.2f}s")
    return hits
# ============================================================================
# 运行入口
# ============================================================================
def _main() -> None:
    """Chunk the input articles, index them incrementally and run a smoke-test query."""
    # Nothing in this file configures logging. Without a handler the INFO
    # progress messages below may be dropped. basicConfig does nothing when
    # the root logger already has handlers.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    logger.info("=" * 60)
    logger.info("【Milvus 混合向量索引构建启动】dense + sparse(BM25)")
    logger.info("=" * 60)
    process_all_documents(INPUT_FILE, OUTPUT_CHUNK_FILE)
    logger.info("✅ 文本分块处理完成")
    chunk_data = load_chunk_jsonl(OUTPUT_CHUNK_FILE)
    logger.info(f"✅ 加载分块数据:{len(chunk_data)}")
    vs = VectorStore(drop_old=False)
    build_index(chunk_data, vs)
    logger.info("✅ 索引构建完成(增量写入)")
    res = search_eval("测试检索内容")
    logger.info(f"✅ 检索完成,命中数量:{len(res)}")
    for doc, score in res:
        logger.info(f"score={score:.4f} | content={doc.page_content[:80]}...")
    logger.info("=" * 60)
    logger.info("【全部执行完成】")


if __name__ == "__main__":
    _main()