551 lines
20 KiB
Python
551 lines
20 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
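function/vector_store.py
Vector store module - integrated with the kb_service project.
Modified: drop_old is always False, so existing collections are never dropped.
✅ Fixed the 413 "over-long token" problem (semantics-friendly splitting).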
|
||
"""
|
||
|
||
import re
|
||
import json
|
||
import logging
|
||
from typing import Dict, List, Optional, Tuple
|
||
from pathlib import Path
|
||
|
||
from langchain_core.documents import Document
|
||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||
from langchain_openai import OpenAIEmbeddings
|
||
from langchain_milvus import Milvus, BM25BuiltInFunction
|
||
from pymilvus import MilvusClient, connections
|
||
|
||
from config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# ============================================================================
|
||
# 配置
|
||
# ============================================================================
|
||
# Collection used whenever a caller does not name one explicitly.
COLLECTION_NAME = "eval_report"

# Endpoints and credentials, read once at import time from the project settings.
EMBEDDING_API_BASE = settings.EMBEDDING_API_BASE
EMBEDDING_API_KEY = settings.EMBEDDING_API_KEY
MILVUS_DB_URL = settings.MILVUS_DB_URL

# Options forwarded unchanged to the langchain-milvus ``Milvus`` constructor.
CONSISTENCY_LEVEL, AUTO_ID = "Bounded", True

# Index parameters for the dense (embedding) vector field.
METRIC_TYPE, INDEX_TYPE = "COSINE", "AUTOINDEX"

# Index parameters for the sparse (BM25) vector field.
SPARSE_METRIC_TYPE, SPARSE_INDEX_TYPE = "BM25", "SPARSE_INVERTED_INDEX"
|
||
|
||
|
||
def _embedding_batch_limits() -> tuple[int, int, int]:
    """Return the embedding batching limits as ``(max_docs, max_chars, max_chunk)``.

    Each value is looked up on ``settings``; a missing or falsy setting falls
    back to its default, and the result is clamped to a lower bound
    (1 document, 512 characters per batch, 512 characters per chunk).
    """

    def _limit(setting_name: str, default: int, floor: int) -> int:
        # ``or default`` also covers settings that exist but are 0 / None / "".
        configured = getattr(settings, setting_name, default) or default
        return max(floor, int(configured))

    return (
        _limit("EMBEDDING_BATCH_MAX_DOCS", 4, 1),
        _limit("EMBEDDING_BATCH_MAX_CHARS", 12000, 512),
        _limit("EMBEDDING_MAX_CHUNK_CHARS", 4000, 512),
    )
|
||
|
||
|
||
def _is_embedding_backend_oom(exc: BaseException) -> bool:
|
||
msg = str(exc).lower()
|
||
return (
|
||
"out of memory" in msg
|
||
or "npu out of memory" in msg
|
||
or "cuda out of memory" in msg
|
||
or "error code: 424" in msg
|
||
or "'code': 424" in msg
|
||
)
|
||
|
||
|
||
def _add_documents_batch_with_retry(vs: Milvus, batch: List[Document]) -> List[str]:
    """Insert one batch of documents, halving it and retrying on embedding OOM.

    Args:
        vs: Vector store whose ``add_documents`` performs the insert.
        batch: Documents to insert; an empty batch is a no-op.

    Returns:
        The IDs returned by the store, first half before second half when the
        batch had to be split.

    Raises:
        Exception: Any insert error that is not an embedding OOM, or an OOM on
            a batch that is already down to a single document.
    """
    if not batch:
        return []

    try:
        return list(vs.add_documents(batch))
    except Exception as err:
        if _is_embedding_backend_oom(err) and len(batch) > 1:
            split_at = max(1, len(batch) // 2)
            logger.warning(
                "embedding 批次 OOM,拆分为 %s + %s 重试",
                split_at,
                len(batch) - split_at,
            )
            head_ids = _add_documents_batch_with_retry(vs, batch[:split_at])
            tail_ids = _add_documents_batch_with_retry(vs, batch[split_at:])
            return [*head_ids, *tail_ids]
        raise
|
||
|
||
|
||
def _register_milvus_client_for_orm(client: MilvusClient) -> None:
    """Make a ``MilvusClient`` connection visible to the pymilvus ORM layer.

    pymilvus 2.6+ ``MilvusClient`` uses a ConnectionManager, while the ORM
    ``Collection`` still resolves ``pymilvus.orm.connections`` through
    ``client._using``. langchain-milvus touches ``Collection`` inside
    ``Milvus.__init__``, so the alias has to be registered before ``Milvus``
    is constructed (bootstrap client).

    Nothing is done when the alias is already registered.
    """
    alias = client._using
    if not connections.has_connection(alias):
        cfg = client._config
        # Private registries of pymilvus.orm.connections: handler first, then config.
        connections._alias_handlers[alias] = client._handler
        connections._alias_config[alias] = dict(
            address=cfg.address,
            user="",
            db_name=cfg.db_name or "default",
        )
|
||
|
||
|
||
# ============================================================================
|
||
# VectorStore 类(已全部改为 drop_old=False)
|
||
# ============================================================================
|
||
|
||
class VectorStore:
    """Wrapper around a langchain-milvus hybrid (dense + sparse/BM25) collection.

    The underlying ``Milvus`` object is created lazily on first use and cached
    on the instance. The collection is never dropped: ``Milvus`` is always
    constructed with ``drop_old=False`` regardless of the ``drop_old`` flags
    accepted by this class.
    """

    def __init__(
        self,
        collection_name: str = COLLECTION_NAME,
        drop_old: bool = False,
        chunk_size: int = 500,
        chunk_overlap: int = 50
    ):
        """Store the configuration; no connection is opened here.

        Args:
            collection_name: Milvus collection to read from and write to.
            drop_old: Only forces the first ``add_documents`` call to rebuild
                the cached ``Milvus`` instance; it does not drop the collection.
            chunk_size: Stored on the instance; not used by the methods below.
            chunk_overlap: Stored on the instance; not used by the methods below.
        """
        self.collection_name = collection_name
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self._drop_old = drop_old
        self._milvus = None

    def _get_embeddings(self):
        """Build the OpenAI-compatible embedding client for the ``bge-m3`` model."""
        return OpenAIEmbeddings(
            base_url=EMBEDDING_API_BASE,
            api_key=EMBEDDING_API_KEY,
            model="bge-m3",
            check_embedding_ctx_length=False,
        )

    def _get_milvus(self, drop_old: bool = False) -> Milvus:
        """Return the cached ``Milvus`` store, creating it on first use.

        Args:
            drop_old: When true, the cached instance is bypassed and a new one
                is built. The collection itself is still kept, because
                ``Milvus`` is always constructed with ``drop_old=False``.

        Raises:
            ValueError: ``MILVUS_DB_URL`` is not configured.
            Exception: Any error raised while constructing ``Milvus`` is
                logged with its traceback and re-raised.
        """
        logger.info("【VectorStore】初始化 Milvus 混合向量存储(dense + sparse)")

        if self._milvus is not None and not drop_old:
            logger.info("【VectorStore】复用已有 Milvus 实例")
            return self._milvus

        if not MILVUS_DB_URL:
            raise ValueError("MILVUS_DB_URL 未配置,请在 .env 中设置")

        embeddings = self._get_embeddings()
        logger.info("【VectorStore】Embedding 模型 bge-m3 初始化完成")

        try:
            # The bootstrap client shares the ConnectionManager with the client
            # langchain creates internally; its ORM alias must be registered
            # first, otherwise the Collection access in Milvus.__init__ fails.
            _register_milvus_client_for_orm(MilvusClient(uri=MILVUS_DB_URL))
            self._milvus = Milvus(
                embedding_function=embeddings,
                builtin_function=BM25BuiltInFunction(),
                vector_field=["dense", "sparse"],
                connection_args={"uri": MILVUS_DB_URL},
                collection_name=self.collection_name,
                consistency_level=CONSISTENCY_LEVEL,
                auto_id=AUTO_ID,
                drop_old=False,
                index_params=[
                    {"metric_type": METRIC_TYPE, "index_type": INDEX_TYPE},
                    {"metric_type": SPARSE_METRIC_TYPE, "index_type": SPARSE_INDEX_TYPE},
                ],
            )
            _register_milvus_client_for_orm(self._milvus.client)
            logger.info("✅ Milvus 混合向量存储初始化成功")
        except Exception as e:
            logger.error(f"❌ Milvus 初始化失败: {str(e)}", exc_info=True)
            raise

        return self._milvus

    def add_documents(self, documents: List[Document]) -> List[str]:
        """Split over-long documents, fill required metadata and insert in batches.

        Documents longer than the configured chunk limit are split on
        paragraph/sentence separators; shorter ones are passed through as-is.
        Metadata fields ``doc_id``, ``original_title``, ``path`` and
        ``project_uuid`` are filled with fallbacks when missing (this updates
        the metadata of pass-through documents in place).

        A batch that fails to insert is logged and skipped, so the returned
        list may cover fewer documents than were submitted.

        Args:
            documents: Documents to insert; an empty list is a no-op.

        Returns:
            IDs returned by the store for the batches that were written.
        """
        if not documents:
            logger.info("【add_documents】无文档可写入")
            return []

        max_docs_per_batch, max_chars_per_batch, max_chunk_chars = _embedding_batch_limits()

        # Only documents that really exceed the limit are re-split, and only on
        # paragraph / sentence boundaries, so the existing structure is kept.
        safe_splitter = RecursiveCharacterTextSplitter(
            chunk_size=max_chunk_chars,
            chunk_overlap=min(200, max(0, max_chunk_chars // 20)),
            separators=["\n\n", "\n", "。", "!", "?", ";", ":", ","]
        )

        safe_documents = []
        for doc in documents:
            if len(doc.page_content) > max_chunk_chars:
                for chunk in safe_splitter.split_text(doc.page_content):
                    if chunk.strip():
                        safe_documents.append(Document(
                            page_content=chunk,
                            # Each chunk gets its own metadata dict.
                            metadata=dict(doc.metadata or {})
                        ))
            else:
                safe_documents.append(doc)

        # The existing collection requires some metadata fields; older callers
        # do not always pass them, so fill in fallbacks here.
        for idx, doc in enumerate(safe_documents):
            metadata = doc.metadata or {}
            if not metadata.get("doc_id"):
                project_uuid = metadata.get("project_uuid") or "unknown_project"
                heading = metadata.get("heading") or "chunk"
                metadata["doc_id"] = f"{project_uuid}:{heading}:{idx}"
            if "original_title" not in metadata:
                metadata["original_title"] = metadata.get("heading") or ""
            if "path" not in metadata:
                metadata["path"] = ""
            if "project_uuid" not in metadata:
                metadata["project_uuid"] = "unknown_project"
            doc.metadata = metadata

        logger.info(f"【add_documents】预处理后准备写入 {len(safe_documents)} 条文档")
        vs = self._get_milvus(drop_old=self._drop_old)
        # The rebuild request is honoured once; later calls reuse the instance.
        self._drop_old = False

        ids = []
        current_batch: List[Document] = []
        current_batch_chars = 0
        batch_num = 1

        def _flush_batch() -> None:
            """Write the pending batch (best effort) and reset the accumulators."""
            nonlocal current_batch, current_batch_chars, batch_num
            if not current_batch:
                return
            logger.info(
                "【add_documents】写入批次 %s,数量:%s,约 %s 字符",
                batch_num,
                len(current_batch),
                current_batch_chars,
            )
            try:
                res = _add_documents_batch_with_retry(vs, current_batch)
                ids.extend(res)
                logger.info("✅ 批次写入成功,返回 ID 数:%s", len(res))
            except Exception as e:
                # Deliberate best effort: one bad batch must not stop the rest.
                logger.error("❌ 批次写入失败: %s", e, exc_info=True)
            batch_num += 1
            current_batch = []
            current_batch_chars = 0

        for doc in safe_documents:
            doc_chars = len(doc.page_content or "")
            would_exceed_docs = bool(current_batch) and len(current_batch) >= max_docs_per_batch
            would_exceed_chars = bool(current_batch) and (
                current_batch_chars + doc_chars > max_chars_per_batch
            )
            # Flush before adding, so a single oversized document still gets a
            # batch of its own instead of being dropped.
            if would_exceed_docs or would_exceed_chars:
                _flush_batch()
            current_batch.append(doc)
            current_batch_chars += doc_chars

        _flush_batch()

        logger.info(f"【add_documents】全部完成,总写入 ID 数:{len(ids)}")
        return ids

    def similarity_search_with_score(
        self, query: str, k: int = 10, filter: Optional[str] = None
    ) -> List[Tuple[Document, float]]:
        """Search through langchain-milvus and return ``(document, score)`` pairs.

        Args:
            query: Query text; truncated to its first 5000 characters.
            k: Number of results requested.
            filter: Optional Milvus filter expression; only forwarded when truthy.
        """
        vs = self._get_milvus(drop_old=False)
        query = query[:5000]
        if filter:
            return vs.similarity_search_with_score(query, k=k, filter=filter)
        return vs.similarity_search_with_score(query, k=k)

    def similarity_search_dense_filtered(
        self,
        query: str,
        k: int,
        filter_expr: str,
    ) -> List[Tuple[Document, float]]:
        """Dense-vector ANN search with a Milvus scalar filter.

        On hybrid (dense + sparse) collections the langchain_milvus ``filter``
        may not take effect; the extraction-side recall uses this path to
        guarantee ``doc_id`` isolation.

        Args:
            query: Query text; truncated to 5000 characters. A blank query
                returns an empty list without calling the backend.
            k: Number of hits requested (at least 1 is always requested).
            filter_expr: Milvus filter expression applied to the search.

        Returns:
            ``(document, score)`` pairs, where the score is the hit's
            ``distance`` as a float, or 0.0 when it is missing or not numeric.
        """
        from pymilvus import MilvusClient

        q = (query or "")[:5000]
        if not q.strip():
            return []
        emb = self._get_embeddings().embed_query(q)
        client = MilvusClient(uri=MILVUS_DB_URL)
        try:
            raw = client.search(
                collection_name=self.collection_name,
                data=[emb],
                anns_field="dense",
                limit=max(1, int(k)),
                filter=filter_expr,
                output_fields=[
                    "text",
                    "heading",
                    "heading_level",
                    "doc_id",
                    "project_uuid",
                    "original_title",
                    "path",
                ],
            )
        finally:
            client.close()

        # One query vector was sent, so only the first result list is relevant.
        hits = raw[0] if raw else []
        out: List[Tuple[Document, float]] = []
        for hit in hits:
            ent = hit.get("entity") or {}
            doc = Document(
                page_content=str(ent.get("text") or ""),
                metadata={
                    "heading": ent.get("heading"),
                    "heading_level": ent.get("heading_level"),
                    "doc_id": ent.get("doc_id"),
                    "project_uuid": ent.get("project_uuid"),
                    "original_title": ent.get("original_title"),
                    "path": ent.get("path"),
                },
            )
            dist = hit.get("distance")
            try:
                score = float(dist) if dist is not None else 0.0
            except (TypeError, ValueError):
                score = 0.0
            out.append((doc, score))
        return out

    def delete_by_filter(self, filter_expr: str) -> int:
        """Delete every entity matching *filter_expr* from the collection.

        Args:
            filter_expr: Milvus filter expression selecting the rows to delete.

        Returns:
            The number of rows that matched before deletion, when it could be
            counted. 0 is returned when the collection does not exist, when
            counting was not possible, or when the deletion itself failed
            (the failure is logged, not raised).
        """
        client = None
        try:
            from pymilvus import MilvusClient

            client = MilvusClient(uri=MILVUS_DB_URL)
            if not client.has_collection(self.collection_name):
                return 0

            # The primary key is not always called "id" (langchain-milvus may
            # use a custom PK / auto_id), so look it up in the collection
            # description. Different result shapes are tolerated:
            # is_primary / isPrimary / primary.
            pk_field = None
            describe = client.describe_collection(self.collection_name)
            for f in describe.get("fields", []) or []:
                if f.get("is_primary") or f.get("isPrimary") or f.get("primary"):
                    pk_field = f.get("name")
                    break

            # Counting is informational only; failing to count (or not finding
            # the primary key) must not block the deletion.
            count = 0
            if pk_field:
                try:
                    res = client.query(
                        self.collection_name,
                        filter=filter_expr,
                        output_fields=[pk_field],
                    )
                    count = len(res)
                except Exception as count_err:
                    logger.warning("删除前计数失败(不影响删除): %s", count_err)
                    count = 0

            client.delete(self.collection_name, filter=filter_expr)
            return count
        except Exception as e:
            logger.error(f"删除失败: {e}", exc_info=True)
            return 0
        finally:
            # Always release the client, including on the early return and on
            # errors; a failing close must not mask the result.
            if client is not None:
                try:
                    client.close()
                except Exception as close_err:
                    logger.warning("关闭 Milvus 客户端失败: %s", close_err)
|
||
|
||
|
||
# ============================================================================
|
||
# Markdown 拆分
|
||
# ============================================================================
|
||
|
||
def split_markdown(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> List[str]:
    """Split *text* into overlapping chunks on paragraph and sentence separators.

    Args:
        text: Text to split; empty or ``None`` yields an empty list.
        chunk_size: Chunk size handed to the splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The chunks produced by ``RecursiveCharacterTextSplitter``.
    """
    if text:
        # Separators are tried in this order, coarsest first.
        separators = ["\n\n", "。", "?", "!", "\n", ";", ":", ","]
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=separators,
        ).split_text(text)
    return []
|
||
|
||
def split_markdown_by_headings(content: str, chunk_size=300, chunk_overlap=40) -> List[Document]:
    """Turn Markdown into one document per heading section.

    Every ``#`` .. ``######`` heading line starts a new section. The heading
    text and level go into the metadata (``heading``, ``heading_level``); the
    heading line itself is not part of ``page_content``. Sections whose body
    is blank are skipped.

    When no section has any body text, the whole content is chunked with
    ``split_markdown`` instead, and each chunk carries ``chunk_index`` with an
    empty heading.

    Args:
        content: Markdown text; empty or ``None`` yields an empty list.
        chunk_size: Chunk size, used only by the fallback chunking.
        chunk_overlap: Chunk overlap, used only by the fallback chunking.
    """
    if not content:
        return []

    heading_re = re.compile(r"^(#{1,6})\s+(.+)$")

    # Each entry is (heading, level, body lines); the first one collects any
    # text that appears before the first heading.
    sections = [("", 0, [])]
    for raw_line in content.split("\n"):
        stripped = raw_line.rstrip()
        found = heading_re.match(stripped)
        if found:
            sections.append((found.group(2).strip(), len(found.group(1)), []))
        else:
            sections[-1][2].append(stripped)

    docs = []
    for heading, level, body_lines in sections:
        body = "\n".join(body_lines).strip()
        if body:
            docs.append(Document(
                page_content=body,
                metadata={"heading": heading, "heading_level": level},
            ))
    if docs:
        return docs

    return [
        Document(
            page_content=piece,
            metadata={"chunk_index": position, "heading": "", "heading_level": 0},
        )
        for position, piece in enumerate(split_markdown(content, chunk_size, chunk_overlap))
    ]
|
||
|
||
def process_document_to_vector_store(
    doc_id: str, title: str, content: str, path: str, project_uuid: str, collection_name=COLLECTION_NAME
) -> bool:
    """Split one Markdown document by headings and write it to the vector store.

    Args:
        doc_id: Identifier stored in every chunk's ``doc_id`` metadata.
        title: Stored as ``original_title``.
        content: Markdown body of the document.
        path: Stored as ``path``.
        project_uuid: Stored as ``project_uuid``.
        collection_name: Target Milvus collection.

    Returns:
        True when the document was processed; False when an exception occurred
        or when there were chunks to write but the store returned no IDs at
        all (``VectorStore.add_documents`` logs and skips failed batches, so
        an empty result is the only visible sign that every write failed).
    """
    try:
        vs = VectorStore(collection_name=collection_name, drop_old=False)
        docs = split_markdown_by_headings(content)
        for d in docs:
            d.metadata.update(
                doc_id=doc_id,
                original_title=title,
                path=path,
                project_uuid=project_uuid,
            )
        ids = vs.add_documents(docs)
    except Exception as e:
        # Keep the traceback, as the other error logs in this module do.
        logger.error(f"处理文档失败: {e}", exc_info=True)
        return False

    if docs and not ids:
        logger.error("处理文档失败: 未写入任何向量 (doc_id=%s)", doc_id)
        return False
    return True
|
||
|
||
# ============================================================================
|
||
# 数据预处理
|
||
# ============================================================================
|
||
|
||
# Default paths used by the script entry point: raw articles in, cleaned chunks out.
INPUT_FILE, OUTPUT_CHUNK_FILE = (
    "data/articles.jsonl",
    "data/processed/eval_chunks.jsonl",
)
|
||
|
||
def load_jsonl(filename: str, encoding="utf-8"):
    """Lazily yield one parsed JSON object per non-blank line of *filename*.

    The file is opened when iteration starts and stays open until the
    generator is exhausted or closed.
    """
    with open(filename, encoding=encoding) as handle:
        yield from (json.loads(raw) for raw in handle if raw.strip())
|
||
|
||
def write_jsonl(data, filename, append=False, ensure_ascii=False):
    """Write every item of *data* to *filename* as one JSON document per line.

    Args:
        data: Iterable of JSON-serialisable items.
        filename: Target file, written as UTF-8.
        append: Append to the file instead of overwriting it.
        ensure_ascii: Forwarded to ``json.dumps``.
    """
    with open(filename, "a" if append else "w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=ensure_ascii) + "\n" for record in data
        )
|
||
|
||
def clean_text(text: str) -> str:
    """Normalise raw article text before chunking.

    Steps, in order: drop control characters (newline is kept), drop
    zero-width and line/paragraph separator characters, drop a fixed set of
    mojibake characters, replace HTML-like tags with a newline, collapse
    repeated newlines and spaces, drop one leading punctuation mark, and
    finally remove every character that is not on the whitelist.

    Args:
        text: Raw text; anything that is not a ``str`` yields ``""``.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    if not isinstance(text, str):
        return ""

    # Control characters, except "\n" (\x0A).
    text = re.sub(r"[\x00-\x09\x0B-\x1F\x7F]", "", text)
    # Zero-width characters, direction marks, line/paragraph separators.
    text = re.sub(r"[\u200b-\u200f\u2028\u2029]", "", text)
    # NOTE(review): looks like a stripper for mis-decoded UTF-8 punctuation;
    # as written it also removes ":" — kept unchanged, confirm it is intended.
    text = re.sub(r"[:’“â€â€¢â€¦â€“—]", "", text)
    # HTML-like tags become line breaks so that adjacent blocks stay separated.
    text = re.sub(r"<[^>]+>", "\n", text)
    text = re.sub(r"\n+", "\n", text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"^[。,?!;:]", "", text)

    # Whitelist of characters to keep. The brackets are escaped: previously an
    # unescaped "[]" closed the character class early, the rest of the pattern
    # was parsed as "{}|..." alternation, and this filter removed nothing.
    # NOTE(review): full-width forms (U+FF00-U+FFEF) are kept as well, on the
    # assumption that Chinese text must retain full-width punctuation — confirm.
    allowed = (
        r"[^a-zA-Z0-9_\s\u4e00-\u9fff\uff00-\uffef"
        r",。!?;:、()《》【】·@#$%^&*+=\[\]{}|./<>\"'-]"
    )
    text = re.sub(allowed, "", text)
    return text.strip()
|
||
|
||
def concat_metadata_to_content(title: str, content: str, metadata: dict) -> str:
    """Prefix *content* with a one-line header built from the title and metadata.

    The header lists title, publish time, author and source as
    ``label:value`` items joined by `` | ``, followed by a ``---`` separator
    line and the stripped content. Items whose value is missing (``None``) or
    empty are left out, instead of being rendered as the text "None".

    Args:
        title: Document title.
        content: Chunk text; surrounding whitespace is stripped.
        metadata: Mapping that may contain ``publish_time``, ``author`` and
            ``source``; ``None`` is treated as an empty mapping.

    Returns:
        The header, the separator line and the content as one string.
    """
    meta = metadata or {}
    fields = (
        ("标题", title),
        ("发布时间", meta.get("publish_time")),
        ("作者", meta.get("author")),
        ("来源", meta.get("source")),
    )
    # Test the value itself rather than re-splitting the rendered item, so a
    # value that contains or ends with ":" is still kept.
    header = " | ".join(
        f"{label}:{value}"
        for label, value in fields
        if value is not None and str(value) != ""
    )
    return header + "\n---\n" + content.strip()
|
||
|
||
def process_all_documents(input_file, output_file, chunk_size=500, overlap=50):
    """Clean and chunk every article in *input_file* and write the chunks as JSONL.

    Each input row must provide ``id``, ``title`` and ``content``; ``metadata``
    is optional. Chunks shorter than 10 characters after cleaning are skipped.
    ``chunk_index`` is the chunk's position in the splitter output, so skipped
    chunks leave gaps in the numbering.

    Args:
        input_file: JSONL file with one article per line.
        output_file: JSONL file to (over)write; its directory is created if
            it does not exist.
        chunk_size: Chunk size handed to the splitter.
        overlap: Overlap between consecutive chunks.

    Returns:
        ``{"num_docs": ..., "num_chunks": ...}`` with the number of articles
        read and chunks written.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=overlap,
        separators=["\n\n", "。", "?", "!", "\n", ";", ":", ","],
    )

    all_chunks = []
    num_docs = 0
    for doc in load_jsonl(input_file):
        num_docs += 1
        # "metadata" may be absent or null in the input row.
        meta = doc.get("metadata") or {}
        content = clean_text(doc["content"])
        # enumerate() supplies the index; iterating the chunk list directly
        # tried to unpack each chunk string into (i, chunk).
        for i, chunk in enumerate(splitter.split_text(content)):
            clean_c = clean_text(chunk)
            if len(clean_c) < 10:
                continue
            all_chunks.append({
                "id": f"{doc['id']}_chunk_{i}",
                "doc_id": doc["id"],
                "title": doc["title"],
                "content": concat_metadata_to_content(doc["title"], clean_c, meta),
                "chunk_index": i,
                "url": meta.get("url", ""),
            })

    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    write_jsonl(all_chunks, output_file)
    return {"num_docs": num_docs, "num_chunks": len(all_chunks)}
|
||
|
||
def load_chunk_jsonl(path):
    """Read a UTF-8 JSONL file into a list, skipping blank lines."""
    with open(path, encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]
|
||
|
||
def build_index(data, vs: VectorStore):
    """Convert chunk rows into documents and write them to *vs*.

    For every row, ``content`` becomes the document text and the remaining
    keys become its metadata. Rows whose content is missing, not a string, or
    shorter than 10 characters after stripping are skipped. The input rows
    are not modified.

    Args:
        data: Iterable of dict rows, e.g. the output of ``load_chunk_jsonl``.
        vs: Vector store that receives the documents; it is not called when
            no row qualifies.
    """
    docs: List[Document] = []
    for row in data:
        # Work on a copy so the caller's rows keep their "content" key.
        metadata = dict(row)
        raw_text = metadata.pop("content", "")
        text = raw_text.strip() if isinstance(raw_text, str) else ""
        if len(text) < 10:
            continue
        docs.append(Document(page_content=text, metadata=metadata))
    if docs:
        vs.add_documents(docs)
|
||
|
||
def get_vector_store(drop_old=False):
    """Return the ``Milvus`` store of a fresh ``VectorStore`` on the default collection.

    ``drop_old`` is passed both to the ``VectorStore`` constructor and to its
    ``_get_milvus`` call.
    """
    return VectorStore(
        collection_name=COLLECTION_NAME,
        drop_old=drop_old,
    )._get_milvus(drop_old=drop_old)
|
||
|
||
def search_eval(query, top_k=10):
    """Run a similarity search on the default collection and print how long it took.

    Args:
        query: Query text.
        top_k: Number of results requested.

    Returns:
        The ``(document, score)`` pairs from
        ``VectorStore.similarity_search_with_score``.
    """
    import time as _time

    store = VectorStore(drop_old=False)
    started = _time.time()
    hits = store.similarity_search_with_score(query, k=top_k)
    elapsed = _time.time() - started
    print(f"检索耗时: {elapsed:.2f}s")
    return hits
|
||
|
||
# ============================================================================
|
||
# 运行入口
|
||
# ============================================================================
|
||
def _main() -> None:
    """Script entry point: chunk the articles, index them, run a smoke-test query."""
    # Without any logging configuration the INFO progress lines below are
    # dropped at the default WARNING level. basicConfig does nothing if the
    # root logger already has handlers.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    logger.info("=" * 60)
    logger.info("【Milvus 混合向量索引构建启动】dense + sparse(BM25)")
    logger.info("=" * 60)

    process_all_documents(INPUT_FILE, OUTPUT_CHUNK_FILE)
    logger.info("✅ 文本分块处理完成")

    chunk_data = load_chunk_jsonl(OUTPUT_CHUNK_FILE)
    logger.info(f"✅ 加载分块数据:{len(chunk_data)} 条")

    vs = VectorStore(drop_old=False)
    build_index(chunk_data, vs)
    logger.info("✅ 索引构建完成(增量写入)")

    res = search_eval("测试检索内容")
    logger.info(f"✅ 检索完成,命中数量:{len(res)}")
    for doc, score in res:
        logger.info(f"score={score:.4f} | content={doc.page_content[:80]}...")

    logger.info("=" * 60)
    logger.info("【全部执行完成】")


if __name__ == "__main__":
    _main()
|