""" database/init_db.py 应用启动时初始化数据库表结构。 执行 init.sql 中的 DDL,使用 IF NOT EXISTS 保证幂等。 """ import re from pathlib import Path from sqlalchemy import text from database.core import engine # DDL 与 init_db.py 同目录:database/init.sql INIT_SQL_PATH = Path(__file__).resolve().parent / "init.sql" INIT_TABLES = [ "projects", "kb_directories", "kb_documents", "write_documents", "doc_versions", "element_tables", "element_cells", "extraction_results", "element_extraction_results", "element_conflicts", "document_markdowns", "document_chunks", "report_templates", "report_template_sections", "report_generation_jobs", "report_generation_chapters", "departments", "users", "roles", "permissions", "role_permissions", "user_roles", "project_members", "project_departments", "fill_records", "report_section_references", ] _TARGET_TABLE_COLLATION = "utf8mb4_unicode_ci" def _existing_tables(conn) -> set[str]: return { row[0] for row in conn.execute( text( "SELECT TABLE_NAME FROM information_schema.TABLES " "WHERE TABLE_SCHEMA = DATABASE()" ) ).fetchall() } def _table_collation(conn, table_name: str) -> str | None: row = conn.execute( text( "SELECT TABLE_COLLATION FROM information_schema.TABLES " "WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = :table_name" ), {"table_name": table_name}, ).first() return str(row[0]).strip() if row and row[0] else None def _column_collation(conn, table_name: str, column_name: str) -> str | None: row = conn.execute( text( "SELECT COLLATION_NAME FROM information_schema.COLUMNS " "WHERE TABLE_SCHEMA = DATABASE() " "AND TABLE_NAME = :table_name AND COLUMN_NAME = :column_name" ), {"table_name": table_name, "column_name": column_name}, ).first() return str(row[0]).strip() if row and row[0] else None def _normalize_projects_table(conn) -> None: """ 将历史库表/列统一为 utf8mb4_unicode_ci(仅在实际不一致时执行 ALTER)。 切勿在每次启动时对已迁移库重复 CONVERT:会长时间持有 metadata lock, 阻塞所有对 projects 等表的读写,并导致连接池耗尽。 """ existing = _existing_tables(conn) tables_to_convert = [ name for name in INIT_TABLES if name in existing and _table_collation(conn, name) != _TARGET_TABLE_COLLATION ] projects_uuid_needs_fix = ( "projects" in existing and _column_collation(conn, "projects", "uuid") != _TARGET_TABLE_COLLATION ) if not tables_to_convert and not projects_uuid_needs_fix: return conn.execute(text("SET FOREIGN_KEY_CHECKS=0")) try: for table_name in tables_to_convert: conn.execute( text( f"ALTER TABLE `{table_name}` " "CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" ) ) if projects_uuid_needs_fix: conn.execute( text( "ALTER TABLE projects " "MODIFY uuid VARCHAR(32) CHARACTER SET utf8mb4 " "COLLATE utf8mb4_unicode_ci NOT NULL" ) ) conn.commit() finally: conn.execute(text("SET FOREIGN_KEY_CHECKS=1")) def _split_sql_statements(content: str) -> list[str]: """ 按分号拆分 SQL 语句,忽略注释和空行。 简单实现:不处理字符串内的分号。 """ # 移除单行注释 content = re.sub(r"--[^\n]*", "", content) # 移除多行注释 content = re.sub(r"/\*.*?\*/", "", content, flags=re.DOTALL) statements = [ s.strip() for s in content.split(";") if s.strip() and not s.strip().startswith("--") ] return statements def init_database() -> None: """ 执行 init.sql,创建表结构,并按需执行缺失字段迁移。 注意:init.sql 里使用了 `CREATE TABLE IF NOT EXISTS`,因此对“已存在但缺列”的旧库, 需要额外执行对应迁移脚本(例如补齐 `kb_documents.factor`)。 """ if not INIT_SQL_PATH.exists(): return sql_text = INIT_SQL_PATH.read_text(encoding="utf-8") statements = _split_sql_statements(sql_text) with engine.connect() as conn: for stmt in statements: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: # 表/索引已存在时忽略(Duplicate key name、already exists) err_msg = str(e).lower() # 历史库可能缺列,导致 CREATE INDEX 报 "Key column ... doesn't exist in table"。 # 这里先跳过,后续 migrate_extraction_results.sql 会补齐列并建索引。 if ( "already exists" in err_msg or "duplicate" in err_msg or ("key column" in err_msg and "doesn't exist in table" in err_msg) or "error 1072" in err_msg ): conn.rollback() continue conn.rollback() raise # 仅在字符集未达标时执行 ALTER(勿在每次 CREATE TABLE projects 后重复调用) _normalize_projects_table(conn) # ------------------------------------------------------------------ # Missing-column migrations (idempotent via "duplicate column" ignore) # ------------------------------------------------------------------ factor_migrate_path = Path(__file__).resolve().parent / "migrate_kb_documents_factor.sql" if factor_migrate_path.exists(): factor_sql_text = factor_migrate_path.read_text(encoding="utf-8") factor_statements = _split_sql_statements(factor_sql_text) with engine.connect() as conn: for stmt in factor_statements: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: # MySQL: Error 1060 "Duplicate column name 'factor'" err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # Missing tables/columns migrations (kb_directories + directory_id) # ------------------------------------------------------------------ kb_dirs_migrate_path = Path(__file__).resolve().parent / "migrate_kb_directories.sql" if kb_dirs_migrate_path.exists(): kb_dirs_sql_text = kb_dirs_migrate_path.read_text(encoding="utf-8") kb_dirs_statements = _split_sql_statements(kb_dirs_sql_text) with engine.connect() as conn: for stmt in kb_dirs_statements: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() # MySQL 常见“已存在/重复”错误:忽略以保证幂等 if ( "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg or "duplicate" in err_msg ): conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # Missing tables/columns migrations (extraction_results legacy schema) # ------------------------------------------------------------------ extraction_migrate_path = Path(__file__).resolve().parent / "migrate_extraction_results.sql" if extraction_migrate_path.exists(): extraction_sql_text = extraction_migrate_path.read_text(encoding="utf-8") extraction_statements = _split_sql_statements(extraction_sql_text) with engine.connect() as conn: for stmt in extraction_statements: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if ( "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg or "duplicate" in err_msg or "check that column/key exists" in err_msg or "error 1072" in err_msg or "doesn't exist" in err_msg ): conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # extraction_results:移除历史 run_id 外键/字段(收敛到 batch_id) # ------------------------------------------------------------------ extraction_drop_run_id_path = ( Path(__file__).resolve().parent / "migrate_extraction_results_drop_run_id.sql" ) if extraction_drop_run_id_path.exists(): extraction_drop_run_id_sql = extraction_drop_run_id_path.read_text(encoding="utf-8") extraction_drop_run_id_statements = _split_sql_statements(extraction_drop_run_id_sql) with engine.connect() as conn: for stmt in extraction_drop_run_id_statements: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if ( "already exists" in err_msg or "duplicate" in err_msg or "doesn't exist" in err_msg or "check that column/key exists" in err_msg or "error 1091" in err_msg ): conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # element_conflicts:补齐 table_id / cell_id(旧库缺列导致 ORM 查询 500) # ------------------------------------------------------------------ ec_migrate_path = Path(__file__).resolve().parent / "migrate_element_conflicts.sql" if ec_migrate_path.exists(): ec_sql_text = ec_migrate_path.read_text(encoding="utf-8") ec_statements = _split_sql_statements(ec_sql_text) with engine.connect() as conn: for stmt in ec_statements: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if ( "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg or "errno 1060" in err_msg ): conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # element_conflicts:兼容历史 project_element_id NOT NULL(改为 NULL) # ------------------------------------------------------------------ ec_project_element_id_path = ( Path(__file__).resolve().parent / "migrate_element_conflicts_project_element_id_nullable.sql" ) if ec_project_element_id_path.exists(): ec_peid_sql = ec_project_element_id_path.read_text(encoding="utf-8") ec_peid_stmts = _split_sql_statements(ec_peid_sql) with engine.connect() as conn: for stmt in ec_peid_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception: conn.rollback() raise # ------------------------------------------------------------------ # element_conflicts:兼容历史 extraction_result_id NOT NULL(改为 NULL) # ------------------------------------------------------------------ ec_extraction_result_id_path = ( Path(__file__).resolve().parent / "migrate_element_conflicts_extraction_result_id_nullable.sql" ) if ec_extraction_result_id_path.exists(): ec_erid_sql = ec_extraction_result_id_path.read_text(encoding="utf-8") ec_erid_stmts = _split_sql_statements(ec_erid_sql) with engine.connect() as conn: for stmt in ec_erid_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception: conn.rollback() raise # ------------------------------------------------------------------ # extraction_results:extracted_at / source_line_end # ------------------------------------------------------------------ ext_time_path = Path(__file__).resolve().parent / "migrate_extraction_results_extracted_at.sql" if ext_time_path.exists(): ext_sql = ext_time_path.read_text(encoding="utf-8") ext_stmts = _split_sql_statements(ext_sql) with engine.connect() as conn: for stmt in ext_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if ( "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg ): conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # element_extraction_results:要素抽取结果明细表(若旧库缺表则补齐) # ------------------------------------------------------------------ el_ext_path = Path(__file__).resolve().parent / "migrate_element_extraction_results.sql" if el_ext_path.exists(): el_ext_sql = el_ext_path.read_text(encoding="utf-8") el_ext_stmts = _split_sql_statements(el_ext_sql) with engine.connect() as conn: for stmt in el_ext_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if ( "already exists" in err_msg or "duplicate" in err_msg ): conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # project_departments:项目可见部门 # ------------------------------------------------------------------ proj_dept_path = Path(__file__).resolve().parent / "migrate_project_departments.sql" if proj_dept_path.exists(): proj_dept_sql = proj_dept_path.read_text(encoding="utf-8") proj_dept_stmts = _split_sql_statements(proj_dept_sql) with engine.connect() as conn: for stmt in proj_dept_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "already exists" in err_msg or "duplicate" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # report_template_sections:章节输出合同(section_output_contract) # ------------------------------------------------------------------ template_section_contract_path = ( Path(__file__).resolve().parent / "migrations" / "add_section_output_contract.sql" ) if template_section_contract_path.exists(): template_section_contract_sql = template_section_contract_path.read_text(encoding="utf-8") template_section_contract_stmts = _split_sql_statements(template_section_contract_sql) with engine.connect() as conn: for stmt in template_section_contract_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # report_section_references:补齐 template_id(按模板过滤参考范文) # ------------------------------------------------------------------ ref_template_id_path = ( Path(__file__).resolve().parent / "migrations" / "add_ref_template_id.sql" ) if ref_template_id_path.exists(): ref_template_id_sql = ref_template_id_path.read_text(encoding="utf-8") ref_template_id_stmts = _split_sql_statements(ref_template_id_sql) with engine.connect() as conn: for stmt in ref_template_id_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if ( "duplicate column" in err_msg or "error 1060" in err_msg or "duplicate key name" in err_msg or "error 1061" in err_msg or "already exists" in err_msg ): conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # users:补齐 password_hash(登录注册) # ------------------------------------------------------------------ users_pwd_path = Path(__file__).resolve().parent / "migrate_users_password_hash.sql" if users_pwd_path.exists(): users_pwd_sql = users_pwd_path.read_text(encoding="utf-8") users_pwd_stmts = _split_sql_statements(users_pwd_sql) with engine.connect() as conn: for stmt in users_pwd_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # departments:补齐 description(部门描述) # ------------------------------------------------------------------ dept_desc_path = Path(__file__).resolve().parent / "migrate_departments_description.sql" if dept_desc_path.exists(): dept_desc_sql = dept_desc_path.read_text(encoding="utf-8") dept_desc_stmts = _split_sql_statements(dept_desc_sql) with engine.connect() as conn: for stmt in dept_desc_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # projects:用户删除的标准模版表不再被「同步模版」回补 # ------------------------------------------------------------------ proj_sup_path = Path(__file__).resolve().parent / "migrate_projects_sync_suppressed_tables.sql" if proj_sup_path.exists(): proj_sup_sql = proj_sup_path.read_text(encoding="utf-8") proj_sup_stmts = _split_sql_statements(proj_sup_sql) with engine.connect() as conn: for stmt in proj_sup_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # element_tables:用户删行后不再被「同步模版」回补 # ------------------------------------------------------------------ et_sup_path = Path(__file__).resolve().parent / "migrate_element_tables_sync_suppressed.sql" if et_sup_path.exists(): et_sup_sql = et_sup_path.read_text(encoding="utf-8") et_sup_stmts = _split_sql_statements(et_sup_sql) with engine.connect() as conn: for stmt in et_sup_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # element_tables:自定义行顺序(加行插在选中行下;刷新后仍保持) # ------------------------------------------------------------------ et_row_order_path = Path(__file__).resolve().parent / "migrate_element_tables_custom_row_order.sql" if et_row_order_path.exists(): et_row_order_sql = et_row_order_path.read_text(encoding="utf-8") et_row_order_stmts = _split_sql_statements(et_row_order_sql) with engine.connect() as conn: for stmt in et_row_order_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # kb_documents:status 语义 v2(0/2/3/4),仅旧库(仍有 status=1 且无 status=4)时执行 # ------------------------------------------------------------------ status_v2_path = Path(__file__).resolve().parent / "migrate_kb_doc_status_v2.sql" if status_v2_path.exists(): with engine.connect() as conn: try: probe = conn.execute( text( """ SELECT SUM(CASE WHEN status = 1 THEN 1 ELSE 0 END) AS s1, SUM(CASE WHEN status = 4 THEN 1 ELSE 0 END) AS s4 FROM kb_documents """ ) ).fetchone() s1 = int((probe[0] if probe else 0) or 0) s4 = int((probe[1] if probe else 0) or 0) if s1 > 0 and s4 == 0: status_v2_sql = status_v2_path.read_text(encoding="utf-8") for stmt in _split_sql_statements(status_v2_sql): stmt = stmt.strip() if not stmt: continue conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "doesn't exist" in err_msg and "kb_documents" in err_msg: conn.rollback() else: conn.rollback() raise # ------------------------------------------------------------------ # kb_documents:storage_rel_path + error_message # ------------------------------------------------------------------ kb_storage_path = Path(__file__).resolve().parent / "migrate_kb_doc_storage_path.sql" if kb_storage_path.exists(): kb_storage_sql = kb_storage_path.read_text(encoding="utf-8") kb_storage_stmts = _split_sql_statements(kb_storage_sql) with engine.connect() as conn: for stmt in kb_storage_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # kb_documents:category (文件分类) # ------------------------------------------------------------------ category_migrate_path = Path(__file__).resolve().parent / "migrate_kb_documents_category.sql" if category_migrate_path.exists(): category_sql_text = category_migrate_path.read_text(encoding="utf-8") category_statements = _split_sql_statements(category_sql_text) with engine.connect() as conn: for stmt in category_statements: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # kb_documents:旧版分类 → 资料清单六大类 # ------------------------------------------------------------------ category_checklist_path = Path(__file__).resolve().parent / "migrate_kb_category_checklist.sql" if category_checklist_path.exists(): checklist_sql_text = category_checklist_path.read_text(encoding="utf-8") checklist_statements = _split_sql_statements(checklist_sql_text) with engine.connect() as conn: for stmt in checklist_statements: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception: conn.rollback() # ------------------------------------------------------------------ # kb_documents:upload_filename(上传/解压原始文件名) # ------------------------------------------------------------------ upload_fn_path = Path(__file__).resolve().parent / "migrate_kb_documents_upload_filename.sql" if upload_fn_path.exists(): upload_fn_sql = upload_fn_path.read_text(encoding="utf-8") upload_fn_stmts = _split_sql_statements(upload_fn_sql) with engine.connect() as conn: for stmt in upload_fn_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise # ------------------------------------------------------------------ # element_cells:source_type(文档抽取 / 手工输入) # ------------------------------------------------------------------ ec_source_type_path = Path(__file__).resolve().parent / "migrate_element_cells_source_type.sql" if ec_source_type_path.exists(): ec_source_sql = ec_source_type_path.read_text(encoding="utf-8") ec_source_stmts = _split_sql_statements(ec_source_sql) with engine.connect() as conn: for stmt in ec_source_stmts: stmt = stmt.strip() if not stmt: continue try: conn.execute(text(stmt)) conn.commit() except Exception as e: err_msg = str(e).lower() if "duplicate column" in err_msg or "error 1060" in err_msg or "already exists" in err_msg: conn.rollback() continue conn.rollback() raise