database_config.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """
  2. 数据库配置管理
  3. 支持Qdrant向量数据库和Neo4j图数据库的配置
  4. """
  5. import os
  6. from dotenv import load_dotenv
  7. from typing import Dict, Any, Optional
  8. from pydantic import BaseModel, Field
  9. import logging
  10. logger = logging.getLogger(__name__)
  11. # Load environment variables early so DB configs pick them up
  12. load_dotenv()
  13. class QdrantConfig(BaseModel):
  14. """Qdrant向量数据库配置"""
  15. # 连接配置
  16. url: Optional[str] = Field(
  17. default=None,
  18. description="Qdrant服务URL (云服务或自定义URL)"
  19. )
  20. api_key: Optional[str] = Field(
  21. default=None,
  22. description="Qdrant API密钥 (云服务需要)"
  23. )
  24. # 集合配置
  25. collection_name: str = Field(
  26. default="hello_agents_vectors",
  27. description="向量集合名称"
  28. )
  29. vector_size: int = Field(
  30. default=384,
  31. description="向量维度"
  32. )
  33. distance: str = Field(
  34. default="cosine",
  35. description="距离度量方式 (cosine, dot, euclidean)"
  36. )
  37. # 连接配置
  38. timeout: int = Field(
  39. default=30,
  40. description="连接超时时间(秒)"
  41. )
  42. @classmethod
  43. def from_env(cls) -> "QdrantConfig":
  44. """从环境变量创建配置"""
  45. return cls(
  46. url=os.getenv("QDRANT_URL"),
  47. api_key=os.getenv("QDRANT_API_KEY"),
  48. collection_name=os.getenv("QDRANT_COLLECTION", "hello_agents_vectors"),
  49. vector_size=int(os.getenv("QDRANT_VECTOR_SIZE", "384")),
  50. distance=os.getenv("QDRANT_DISTANCE", "cosine"),
  51. timeout=int(os.getenv("QDRANT_TIMEOUT", "30"))
  52. )
  53. def to_dict(self) -> Dict[str, Any]:
  54. """转换为字典"""
  55. return self.model_dump(exclude_none=True)
  56. class Neo4jConfig(BaseModel):
  57. """Neo4j图数据库配置"""
  58. # 连接配置
  59. uri: str = Field(
  60. default="bolt://localhost:7687",
  61. description="Neo4j连接URI"
  62. )
  63. username: str = Field(
  64. default="neo4j",
  65. description="用户名"
  66. )
  67. password: str = Field(
  68. default="hello-agents-password",
  69. description="密码"
  70. )
  71. database: str = Field(
  72. default="neo4j",
  73. description="数据库名称"
  74. )
  75. # 连接池配置
  76. max_connection_lifetime: int = Field(
  77. default=3600,
  78. description="最大连接生命周期(秒)"
  79. )
  80. max_connection_pool_size: int = Field(
  81. default=50,
  82. description="最大连接池大小"
  83. )
  84. connection_acquisition_timeout: int = Field(
  85. default=60,
  86. description="连接获取超时(秒)"
  87. )
  88. @classmethod
  89. def from_env(cls) -> "Neo4jConfig":
  90. """从环境变量创建配置"""
  91. return cls(
  92. uri=os.getenv("NEO4J_URI", "bolt://localhost:7687"),
  93. username=os.getenv("NEO4J_USERNAME", "neo4j"),
  94. password=os.getenv("NEO4J_PASSWORD", "hello-agents-password"),
  95. database=os.getenv("NEO4J_DATABASE", "neo4j"),
  96. max_connection_lifetime=int(os.getenv("NEO4J_MAX_CONNECTION_LIFETIME", "3600")),
  97. max_connection_pool_size=int(os.getenv("NEO4J_MAX_CONNECTION_POOL_SIZE", "50")),
  98. connection_acquisition_timeout=int(os.getenv("NEO4J_CONNECTION_TIMEOUT", "60"))
  99. )
  100. def to_dict(self) -> Dict[str, Any]:
  101. """转换为字典"""
  102. return self.model_dump()
  103. class DatabaseConfig(BaseModel):
  104. """数据库配置管理器"""
  105. qdrant: QdrantConfig = Field(
  106. default_factory=QdrantConfig,
  107. description="Qdrant向量数据库配置"
  108. )
  109. neo4j: Neo4jConfig = Field(
  110. default_factory=Neo4jConfig,
  111. description="Neo4j图数据库配置"
  112. )
  113. @classmethod
  114. def from_env(cls) -> "DatabaseConfig":
  115. """从环境变量创建配置"""
  116. return cls(
  117. qdrant=QdrantConfig.from_env(),
  118. neo4j=Neo4jConfig.from_env()
  119. )
  120. def get_qdrant_config(self) -> Dict[str, Any]:
  121. """获取Qdrant配置字典"""
  122. return self.qdrant.to_dict()
  123. def get_neo4j_config(self) -> Dict[str, Any]:
  124. """获取Neo4j配置字典"""
  125. return self.neo4j.to_dict()
  126. def validate_connections(self) -> Dict[str, bool]:
  127. """验证数据库连接配置"""
  128. results = {}
  129. # 验证Qdrant配置
  130. try:
  131. from ..memory.storage.qdrant_store import QdrantVectorStore
  132. qdrant_store = QdrantVectorStore(**self.get_qdrant_config())
  133. results["qdrant"] = qdrant_store.health_check()
  134. logger.info(f"✅ Qdrant连接验证: {'成功' if results['qdrant'] else '失败'}")
  135. except Exception as e:
  136. results["qdrant"] = False
  137. logger.error(f"❌ Qdrant连接验证失败: {e}")
  138. # 验证Neo4j配置
  139. try:
  140. from ..memory.storage.neo4j_store import Neo4jGraphStore
  141. neo4j_store = Neo4jGraphStore(**self.get_neo4j_config())
  142. results["neo4j"] = neo4j_store.health_check()
  143. logger.info(f"✅ Neo4j连接验证: {'成功' if results['neo4j'] else '失败'}")
  144. except Exception as e:
  145. results["neo4j"] = False
  146. logger.error(f"❌ Neo4j连接验证失败: {e}")
  147. return results
  148. # 全局配置实例
  149. db_config = DatabaseConfig.from_env()
  150. def get_database_config() -> DatabaseConfig:
  151. """获取数据库配置"""
  152. return db_config
  153. def update_database_config(**kwargs) -> None:
  154. """更新数据库配置"""
  155. global db_config
  156. if "qdrant" in kwargs:
  157. db_config.qdrant = QdrantConfig(**kwargs["qdrant"])
  158. if "neo4j" in kwargs:
  159. db_config.neo4j = Neo4jConfig(**kwargs["neo4j"])
  160. logger.info("✅ 数据库配置已更新")