xxy aa98ea2623 @
Initial commit

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@
2026-06-05 18:45:29 +08:00

765 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
database/init_db.py
应用启动时初始化数据库表结构。
执行 init.sql 中的 DDL，使用 IF NOT EXISTS 保证幂等。
"""
import logging
import re
from pathlib import Path

from sqlalchemy import text

from database.core import engine
# The DDL script sits in the same directory as this module: database/init.sql
INIT_SQL_PATH = Path(__file__).resolve().with_name("init.sql")

# Tables whose collation _normalize_projects_table() inspects and, when it
# differs from _TARGET_TABLE_COLLATION, converts. ALTERs are issued in this
# order.
INIT_TABLES: list[str] = [
    "projects",
    "kb_directories",
    "kb_documents",
    "write_documents",
    "doc_versions",
    "element_tables",
    "element_cells",
    "extraction_results",
    "element_extraction_results",
    "element_conflicts",
    "document_markdowns",
    "document_chunks",
    "report_templates",
    "report_template_sections",
    "report_generation_jobs",
    "report_generation_chapters",
    "departments",
    "users",
    "roles",
    "permissions",
    "role_permissions",
    "user_roles",
    "project_members",
    "project_departments",
    "fill_records",
    "report_section_references",
]

# Collation that tables in INIT_TABLES (and projects.uuid) are migrated to.
_TARGET_TABLE_COLLATION: str = "utf8mb4_unicode_ci"
def _existing_tables(conn) -> set[str]:
    """
    Collect the names of every table in the connection's current schema.

    Args:
        conn: an open SQLAlchemy connection.

    Returns:
        A set of TABLE_NAME values from information_schema.TABLES for the
        schema selected by DATABASE().
    """
    query = text(
        "SELECT TABLE_NAME FROM information_schema.TABLES "
        "WHERE TABLE_SCHEMA = DATABASE()"
    )
    rows = conn.execute(query).fetchall()
    names: set[str] = set()
    for record in rows:
        names.add(record[0])
    return names
def _table_collation(conn, table_name: str) -> str | None:
    """
    Read the default collation of one table in the current schema.

    Args:
        conn: an open SQLAlchemy connection.
        table_name: table to look up in information_schema.TABLES.

    Returns:
        The stripped TABLE_COLLATION value, or None when the table is not
        found or the reported collation is empty/NULL.
    """
    result = conn.execute(
        text(
            "SELECT TABLE_COLLATION FROM information_schema.TABLES "
            "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = :table_name"
        ),
        {"table_name": table_name},
    )
    first_row = result.first()
    if not first_row:
        return None
    collation = first_row[0]
    if not collation:
        return None
    return str(collation).strip()
def _column_collation(conn, table_name: str, column_name: str) -> str | None:
    """
    Read the collation of a single column in the current schema.

    Args:
        conn: an open SQLAlchemy connection.
        table_name: table that owns the column.
        column_name: column to look up in information_schema.COLUMNS.

    Returns:
        The stripped COLLATION_NAME, or None when the column is not found or
        no collation is reported (MySQL reports NULL for non-character
        columns).
    """
    sql = text(
        "SELECT COLLATION_NAME FROM information_schema.COLUMNS "
        "WHERE TABLE_SCHEMA = DATABASE() "
        "AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name"
    )
    params = {"table_name": table_name, "column_name": column_name}
    found = conn.execute(sql, params).first()
    collation = found[0] if found else None
    return str(collation).strip() if collation else None
def _normalize_projects_table(conn) -> None:
    """
    Migrate legacy tables/columns to utf8mb4_unicode_ci, issuing ALTER only
    when the current collation actually differs from the target.

    Never run CONVERT unconditionally on every startup against an already
    migrated database: it holds a metadata lock for a long time, blocks all
    reads/writes on projects and related tables, and can exhaust the
    connection pool.

    Steps:
      1. ``ALTER TABLE ... CONVERT TO`` every table listed in INIT_TABLES that
         exists and whose table collation differs from _TARGET_TABLE_COLLATION.
      2. ``ALTER TABLE projects MODIFY uuid ...`` if the column collation still
         differs. This is re-checked *after* step 1: converting ``projects``
         already rewrites the collation of its character columns, so issuing
         the MODIFY anyway would rebuild the table a second time.

    FOREIGN_KEY_CHECKS is disabled for the session while altering and is
    re-enabled in ``finally``.

    Args:
        conn: an open SQLAlchemy connection to the target schema.
    """
    existing = _existing_tables(conn)
    tables_to_convert = [
        name
        for name in INIT_TABLES
        if name in existing
        and _table_collation(conn, name) != _TARGET_TABLE_COLLATION
    ]

    def uuid_needs_fix() -> bool:
        # NOTE(review): a missing or non-character uuid column also yields
        # None here and therefore triggers the MODIFY — confirm intended.
        return (
            "projects" in existing
            and _column_collation(conn, "projects", "uuid") != _TARGET_TABLE_COLLATION
        )

    if not tables_to_convert and not uuid_needs_fix():
        return

    conn.execute(text("SET FOREIGN_KEY_CHECKS=0"))
    try:
        for table_name in tables_to_convert:
            # table_name always comes from INIT_TABLES, never from user input.
            conn.execute(
                text(
                    f"ALTER TABLE `{table_name}` "
                    "CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
                )
            )
        # Re-check: if projects was just converted, uuid already carries the
        # target collation and a MODIFY would be a redundant table rebuild.
        if uuid_needs_fix():
            conn.execute(
                text(
                    "ALTER TABLE projects "
                    "MODIFY uuid VARCHAR(32) CHARACTER SET utf8mb4 "
                    "COLLATE utf8mb4_unicode_ci NOT NULL"
                )
            )
        conn.commit()
    finally:
        conn.execute(text("SET FOREIGN_KEY_CHECKS=1"))
def _skip_quoted_sql(content: str, start: int) -> int:
    """
    Return the index just past the quoted token that begins at ``start``.

    ``content[start]`` must be one of ``'``, ``"`` or a backtick. A doubled
    quote character is treated as an escaped quote. A backslash escapes the
    next character inside ``'``/``"`` strings (MySQL default), but not inside
    backtick identifiers. An unterminated token runs to the end of ``content``.
    """
    quote = content[start]
    i = start + 1
    n = len(content)
    while i < n:
        ch = content[i]
        if ch == "\\" and quote != "`":
            i += 2
            continue
        if ch == quote:
            if i + 1 < n and content[i + 1] == quote:
                i += 2
                continue
            return i + 1
        i += 1
    return n


def _split_sql_statements(content: str) -> list[str]:
    """
    Split a MySQL script into individual statements on top-level ``;``.

    Comments are removed:
      - ``--`` and ``#`` run to the end of the line; the newline is kept.
      - ``/* ... */`` is replaced by a single space so that the surrounding
        tokens do not get glued together.

    Quoted strings (``'...'``, ``"..."``) and backtick identifiers are copied
    verbatim. Any ``;``, ``--``, ``#`` or ``/*`` inside them is therefore
    neither a separator nor a comment.

    This is done in one left-to-right pass. Regex-based stripping that removes
    ``--`` before ``/* */`` mangles input such as ``/* a -- b */ CREATE ...``,
    where the line comment eats the closing ``*/`` and the real statement.

    An unterminated block comment is kept as text so MySQL reports the error.
    ``DELIMITER`` directives are not supported.

    Args:
        content: full text of a SQL script.

    Returns:
        Non-empty, whitespace-stripped statements without the trailing ``;``.
    """
    statements: list[str] = []
    current: list[str] = []

    def flush() -> None:
        stmt = "".join(current).strip()
        if stmt:
            statements.append(stmt)
        current.clear()

    i = 0
    n = len(content)
    while i < n:
        ch = content[i]
        pair = content[i:i + 2]
        if ch in "'\"`":
            end = _skip_quoted_sql(content, i)
            current.append(content[i:end])
            i = end
        elif pair == "--" or ch == "#":
            newline = content.find("\n", i)
            i = n if newline == -1 else newline  # keep the newline itself
        elif pair == "/*":
            close = content.find("*/", i + 2)
            if close == -1:
                current.append(content[i:])
                i = n
            else:
                current.append(" ")
                i = close + 2
        elif ch == ";":
            flush()
            i += 1
        else:
            current.append(ch)
            i += 1
    flush()
    return statements
# Lower-cased MySQL error-message fragments meaning "this column/table was
# already added"; shared by most add-column migrations below.
_DUP_COLUMN_MARKERS: tuple[str, ...] = ("duplicate column", "error 1060", "already exists")


def _is_ignorable_error(err_msg: str, markers: tuple[str | tuple[str, ...], ...]) -> bool:
    """
    Return True when ``err_msg`` matches any entry of ``markers``.

    Each marker is either a substring, or a tuple of substrings that must all
    occur in ``err_msg`` for the marker to match.
    """
    for marker in markers:
        parts = (marker,) if isinstance(marker, str) else marker
        if all(part in err_msg for part in parts):
            return True
    return False


def _execute_statements(
    conn,
    statements: list[str],
    ignorable: tuple[str | tuple[str, ...], ...] = (),
    *,
    best_effort: bool = False,
) -> None:
    """
    Execute ``statements`` one at a time on ``conn``, committing after each.

    A failing statement is rolled back. If its lower-cased error message
    matches ``ignorable``, it is skipped silently; this is how the idempotent
    "already applied" cases are handled. Otherwise it is re-raised. When
    ``best_effort`` is True, it is instead logged as a warning and skipped.
    """
    for stmt in statements:
        stmt = stmt.strip()
        if not stmt:
            continue
        try:
            conn.execute(text(stmt))
            conn.commit()
        except Exception as exc:
            conn.rollback()
            if _is_ignorable_error(str(exc).lower(), ignorable):
                continue
            if best_effort:
                logging.getLogger(__name__).warning(
                    "Skipping failed best-effort migration statement %r: %s",
                    stmt[:200],
                    exc,
                )
                continue
            raise


def _run_sql_file(
    path: Path,
    ignorable: tuple[str | tuple[str, ...], ...] = (),
    *,
    best_effort: bool = False,
) -> None:
    """
    Run every statement of the SQL file at ``path`` on a fresh connection.

    Does nothing if the file does not exist. See ``_execute_statements`` for
    the meaning of ``ignorable`` and ``best_effort``.
    """
    if not path.exists():
        return
    statements = _split_sql_statements(path.read_text(encoding="utf-8"))
    with engine.connect() as conn:
        _execute_statements(conn, statements, ignorable, best_effort=best_effort)


def _migrate_kb_doc_status_v2(path: Path) -> None:
    """
    Remap kb_documents.status to the v2 semantics (0/2/3/4) using ``path``.

    The script runs only when the file exists and the data still looks legacy:
    at least one row with status = 1 and no row with status = 4. An error
    saying a table "doesn't exist" that mentions kb_documents is tolerated.
    Any other error is rolled back and re-raised.
    """
    if not path.exists():
        return
    with engine.connect() as conn:
        try:
            probe = conn.execute(
                text(
                    """
                    SELECT
                    SUM(CASE WHEN status = 1 THEN 1 ELSE 0 END) AS s1,
                    SUM(CASE WHEN status = 4 THEN 1 ELSE 0 END) AS s4
                    FROM kb_documents
                    """
                )
            ).fetchone()
            # SUM() over an empty table returns NULL, hence the "or 0".
            s1 = int((probe[0] if probe else 0) or 0)
            s4 = int((probe[1] if probe else 0) or 0)
            if s1 > 0 and s4 == 0:
                for stmt in _split_sql_statements(path.read_text(encoding="utf-8")):
                    stmt = stmt.strip()
                    if not stmt:
                        continue
                    conn.execute(text(stmt))
                    conn.commit()
        except Exception as exc:
            err_msg = str(exc).lower()
            conn.rollback()
            if not ("doesn't exist" in err_msg and "kb_documents" in err_msg):
                raise


def init_database() -> None:
    """
    Run init.sql to create the schema, then apply the legacy-schema migrations.

    init.sql uses ``CREATE TABLE IF NOT EXISTS``, so an older database whose
    tables exist but lack newer columns is not fixed by it. The migration
    scripts next to this module fill those gaps; for example,
    ``migrate_kb_documents_factor.sql`` adds ``kb_documents.factor``.

    Behaviour:
      - Each migration file is optional and is skipped when absent.
      - Migrations run in a fixed order, each on its own connection.
      - Per-file error markers decide which failures are treated as
        "already applied".
      - If init.sql itself is missing, nothing is done.
    """
    if not INIT_SQL_PATH.exists():
        return

    statements = _split_sql_statements(INIT_SQL_PATH.read_text(encoding="utf-8"))
    with engine.connect() as conn:
        # Existing tables/indexes are skipped ("already exists"/"duplicate").
        # Legacy databases may lack columns, so CREATE INDEX can fail with
        # "Key column ... doesn't exist in table". That is skipped here as
        # well; migrate_extraction_results.sql adds the columns and indexes
        # later.
        _execute_statements(
            conn,
            statements,
            (
                "already exists",
                "duplicate",
                ("key column", "doesn't exist in table"),
                "error 1072",
            ),
        )
        # Only ALTERs when the collation is actually wrong; do not re-run
        # unconditionally after every CREATE TABLE projects.
        _normalize_projects_table(conn)

    base_dir = INIT_SQL_PATH.parent

    # kb_documents.factor
    _run_sql_file(
        base_dir / "migrate_kb_documents_factor.sql",
        ("duplicate column", "error 1060"),
    )
    # kb_directories table + directory_id column
    _run_sql_file(
        base_dir / "migrate_kb_directories.sql",
        ("duplicate column", "error 1060", "already exists", "duplicate"),
    )
    # extraction_results legacy schema
    _run_sql_file(
        base_dir / "migrate_extraction_results.sql",
        (
            "duplicate column",
            "error 1060",
            "already exists",
            "duplicate",
            "check that column/key exists",
            "error 1072",
            "doesn't exist",
        ),
    )
    # extraction_results: drop the legacy run_id FK/column (converged on batch_id)
    _run_sql_file(
        base_dir / "migrate_extraction_results_drop_run_id.sql",
        (
            "already exists",
            "duplicate",
            "doesn't exist",
            "check that column/key exists",
            "error 1091",
        ),
    )
    # element_conflicts: add table_id / cell_id (missing columns broke ORM queries)
    _run_sql_file(
        base_dir / "migrate_element_conflicts.sql",
        ("duplicate column", "error 1060", "already exists", "errno 1060"),
    )
    # element_conflicts: legacy NOT NULL project_element_id -> NULL (no errors tolerated)
    _run_sql_file(base_dir / "migrate_element_conflicts_project_element_id_nullable.sql")
    # element_conflicts: legacy NOT NULL extraction_result_id -> NULL (no errors tolerated)
    _run_sql_file(base_dir / "migrate_element_conflicts_extraction_result_id_nullable.sql")
    # extraction_results: extracted_at / source_line_end
    _run_sql_file(
        base_dir / "migrate_extraction_results_extracted_at.sql",
        _DUP_COLUMN_MARKERS,
    )
    # element_extraction_results detail table (created if missing on old databases)
    _run_sql_file(
        base_dir / "migrate_element_extraction_results.sql",
        ("already exists", "duplicate"),
    )
    # project_departments: departments a project is visible to
    _run_sql_file(
        base_dir / "migrate_project_departments.sql",
        ("already exists", "duplicate"),
    )
    # report_template_sections: section_output_contract
    _run_sql_file(
        base_dir / "migrations" / "add_section_output_contract.sql",
        _DUP_COLUMN_MARKERS,
    )
    # report_section_references: template_id (filter reference samples by template)
    _run_sql_file(
        base_dir / "migrations" / "add_ref_template_id.sql",
        (
            "duplicate column",
            "error 1060",
            "duplicate key name",
            "error 1061",
            "already exists",
        ),
    )
    # users.password_hash (login/registration)
    _run_sql_file(base_dir / "migrate_users_password_hash.sql", _DUP_COLUMN_MARKERS)
    # departments.description
    _run_sql_file(base_dir / "migrate_departments_description.sql", _DUP_COLUMN_MARKERS)
    # projects: standard template tables deleted by the user are not re-added by "sync templates"
    _run_sql_file(
        base_dir / "migrate_projects_sync_suppressed_tables.sql",
        _DUP_COLUMN_MARKERS,
    )
    # element_tables: rows deleted by the user are not re-added by "sync templates"
    _run_sql_file(
        base_dir / "migrate_element_tables_sync_suppressed.sql",
        _DUP_COLUMN_MARKERS,
    )
    # element_tables: custom row order (rows inserted below the selected row persist after refresh)
    _run_sql_file(
        base_dir / "migrate_element_tables_custom_row_order.sql",
        _DUP_COLUMN_MARKERS,
    )
    # kb_documents: status semantics v2 (only while legacy data is present)
    _migrate_kb_doc_status_v2(base_dir / "migrate_kb_doc_status_v2.sql")
    # kb_documents: storage_rel_path + error_message
    _run_sql_file(base_dir / "migrate_kb_doc_storage_path.sql", _DUP_COLUMN_MARKERS)
    # kb_documents: category (file classification)
    _run_sql_file(base_dir / "migrate_kb_documents_category.sql", _DUP_COLUMN_MARKERS)
    # kb_documents: legacy categories -> the six checklist categories.
    # Best-effort: every failure is skipped, but it is now logged instead of
    # being silently swallowed.
    _run_sql_file(base_dir / "migrate_kb_category_checklist.sql", best_effort=True)
    # kb_documents.upload_filename (original uploaded/extracted file name)
    _run_sql_file(
        base_dir / "migrate_kb_documents_upload_filename.sql",
        _DUP_COLUMN_MARKERS,
    )
    # element_cells.source_type (document extraction / manual input)
    _run_sql_file(
        base_dir / "migrate_element_cells_source_type.sql",
        _DUP_COLUMN_MARKERS,
    )