neo4j_store.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  """
  Neo4j graph database storage implementation.

  Provides ``Neo4jGraphStore``, a wrapper around the official ``neo4j`` Python
  driver that stores Entity nodes and the relationships between them.
  """
  import logging
  from typing import Dict, List, Optional, Any, Tuple
  from datetime import datetime
  # The neo4j driver is an optional dependency. Import it defensively so this
  # module can still be imported without it; Neo4jGraphStore.__init__ checks
  # NEO4J_AVAILABLE and raises ImportError when the driver is missing.
  try:
      from neo4j import GraphDatabase
      from neo4j.exceptions import ServiceUnavailable, AuthError
      NEO4J_AVAILABLE = True
  except ImportError:
      NEO4J_AVAILABLE = False
      # NOTE: ServiceUnavailable / AuthError stay undefined on this path. That is
      # only safe because the code catching them runs after the NEO4J_AVAILABLE check.
      GraphDatabase = None
  logger = logging.getLogger(__name__)
  class Neo4jGraphStore:
      """Neo4j-backed graph store for Entity nodes and their relationships.

      Wraps a single ``neo4j`` driver, which owns a connection pool. Call
      :meth:`close`, or use the instance as a context manager, to release the
      pool deterministically. ``__del__`` only does so on a best-effort basis.
      """

      def __init__(
          self,
          uri: str = "bolt://localhost:7687",
          username: str = "neo4j",
          password: str = "hello-agents-password",
          database: str = "neo4j",
          max_connection_lifetime: int = 3600,
          max_connection_pool_size: int = 50,
          connection_acquisition_timeout: int = 60,
          **kwargs
      ):
          """Connect to Neo4j (local or cloud/Aura) and make sure the indexes exist.

          Args:
              uri: Connection URI (local: ``bolt://localhost:7687``,
                  cloud: ``neo4j+s://xxx.databases.neo4j.io``).
              username: Login user name.
              password: Login password.
              database: Name of the database every session targets.
              max_connection_lifetime: Maximum lifetime of a pooled connection (seconds).
              max_connection_pool_size: Maximum size of the connection pool.
              connection_acquisition_timeout: Timeout for acquiring a pooled connection (seconds).
              **kwargs: Accepted for call-site compatibility; currently ignored.

          Raises:
              ImportError: If the ``neo4j`` package is not installed.
              Exception: Whatever the driver raises when the connection cannot be
                  verified (e.g. ``AuthError``, ``ServiceUnavailable``). The driver
                  is closed before the error propagates.
          """
          if not NEO4J_AVAILABLE:
              raise ImportError(
                  "neo4j未安装。请运行: pip install neo4j>=5.0.0"
              )
          self.uri = uri
          self.username = username
          self.password = password
          self.database = database
          self.driver = None
          self._initialize_driver(
              max_connection_lifetime=max_connection_lifetime,
              max_connection_pool_size=max_connection_pool_size,
              connection_acquisition_timeout=connection_acquisition_timeout,
          )
          try:
              self._create_indexes()
          except Exception:
              # Don't leave an open connection pool behind a half-built object.
              self._close_quietly()
              raise

      def _initialize_driver(self, **config):
          """Create ``self.driver`` and verify that it can reach the server.

          On failure, a hint is logged, the freshly created driver is closed so its
          pool is not leaked, and the original exception is re-raised.
          """
          try:
              self.driver = GraphDatabase.driver(
                  self.uri,
                  auth=(self.username, self.password),
                  **config
              )
              self.driver.verify_connectivity()
          except AuthError as e:
              logger.error(f"❌ Neo4j认证失败: {e}")
              logger.info("💡 请检查用户名和密码是否正确")
              self._close_quietly()
              raise
          except ServiceUnavailable as e:
              logger.error(f"❌ Neo4j服务不可用: {e}")
              if "localhost" in self.uri:
                  logger.info("💡 本地连接失败,可以考虑使用Neo4j Aura云服务")
                  logger.info("💡 或启动本地服务: docker run -p 7474:7474 -p 7687:7687 neo4j:5.14")
              else:
                  logger.info("💡 请检查URL和网络连接")
              self._close_quietly()
              raise
          except Exception as e:
              logger.error(f"❌ Neo4j连接失败: {e}")
              self._close_quietly()
              raise

          # Cloud (Aura) URIs are detected only to pick a friendlier log message.
          if "neo4j.io" in self.uri or "aura" in self.uri.lower():
              logger.info(f"✅ 成功连接到Neo4j云服务: {self.uri}")
          else:
              logger.info(f"✅ 成功连接到Neo4j服务: {self.uri}")

      def _create_indexes(self):
          """Create the lookup indexes used by this class's queries, if missing.

          Each statement uses ``IF NOT EXISTS``. A failing statement is logged at
          debug level and skipped, so one bad index does not block the others.
          """
          indexes = [
              # Entity indexes
              "CREATE INDEX entity_id_index IF NOT EXISTS FOR (e:Entity) ON (e.id)",
              "CREATE INDEX entity_name_index IF NOT EXISTS FOR (e:Entity) ON (e.name)",
              "CREATE INDEX entity_type_index IF NOT EXISTS FOR (e:Entity) ON (e.type)",
              # Memory indexes
              "CREATE INDEX memory_id_index IF NOT EXISTS FOR (m:Memory) ON (m.id)",
              "CREATE INDEX memory_type_index IF NOT EXISTS FOR (m:Memory) ON (m.memory_type)",
              "CREATE INDEX memory_timestamp_index IF NOT EXISTS FOR (m:Memory) ON (m.timestamp)",
          ]
          with self.driver.session(database=self.database) as session:
              for index_query in indexes:
                  try:
                      # consume() finishes the statement inside this try, so an
                      # error is reported against the index that caused it.
                      session.run(index_query).consume()
                  except Exception as e:
                      logger.debug(f"索引创建跳过 (可能已存在): {e}")
          logger.info("✅ Neo4j索引创建完成")

      @staticmethod
      def _quote_identifier(name: str) -> str:
          """Return *name* as a backtick-quoted Cypher identifier.

          Relationship types are interpolated into the query text rather than
          passed as parameters. Quoting them, with embedded backticks doubled,
          prevents Cypher injection and allows spaces, hyphens and other characters.
          """
          return "`" + name.replace("`", "``") + "`"

      @staticmethod
      def _relationship_type_filter(relationship_types: Optional[List[str]]) -> str:
          """Build a relationship-type filter such as ``:`A`|`B```; "" means any type."""
          if not relationship_types:
              return ""
          return ":" + "|".join(
              Neo4jGraphStore._quote_identifier(t) for t in relationship_types
          )

      @staticmethod
      def _merge_properties(
          properties: Optional[Dict[str, Any]], system_fields: Dict[str, Any]
      ) -> Dict[str, Any]:
          """Return a NEW dict: the caller's *properties* overlaid with *system_fields*.

          The caller's dict is never mutated. ``created_at`` is dropped because it
          is written only when the node or relationship is first created
          (``ON CREATE SET``).
          """
          props = dict(properties or {})
          props.pop("created_at", None)
          props.update(system_fields)
          return props

      def add_entity(self, entity_id: str, name: str, entity_type: str, properties: Dict[str, Any] = None) -> bool:
          """Create or update an Entity node, matched (MERGE) on ``id``.

          Args:
              entity_id: Entity ID (the MERGE key).
              name: Entity name.
              entity_type: Entity type, stored in the ``type`` property.
              properties: Extra properties. The dict is not modified.

          Returns:
              bool: True if the node was written; False on failure (the error is logged).
          """
          try:
              now = datetime.now().isoformat()
              props = self._merge_properties(properties, {
                  "id": entity_id,
                  "name": name,
                  "type": entity_type,
                  "updated_at": now,
              })
              # created_at is set only when MERGE creates the node, so updating an
              # existing entity keeps its original creation time.
              query = """
              MERGE (e:Entity {id: $entity_id})
              ON CREATE SET e.created_at = $now
              SET e += $properties
              RETURN e
              """
              with self.driver.session(database=self.database) as session:
                  result = session.run(query, entity_id=entity_id, now=now, properties=props)
                  record = result.single()
                  if record:
                      logger.debug(f"✅ 添加实体: {name} ({entity_type})")
                      return True
                  return False
          except Exception as e:
              logger.error(f"❌ 添加实体失败: {e}")
              return False

      def add_relationship(
          self,
          from_entity_id: str,
          to_entity_id: str,
          relationship_type: str,
          properties: Dict[str, Any] = None
      ) -> bool:
          """Create or update a relationship between two existing Entity nodes.

          Args:
              from_entity_id: Source entity ID.
              to_entity_id: Target entity ID.
              relationship_type: Relationship type. It is backtick-quoted, so any
                  non-empty string is accepted safely.
              properties: Relationship properties. The dict is not modified.

          Returns:
              bool: True if the relationship was written; False if either endpoint
              is missing or on failure (the error is logged).
          """
          if not relationship_type:
              logger.error("❌ 添加关系失败: 关系类型不能为空")
              return False
          try:
              now = datetime.now().isoformat()
              props = self._merge_properties(properties, {
                  "type": relationship_type,
                  "updated_at": now,
              })
              rel_type = self._quote_identifier(relationship_type)
              query = f"""
              MATCH (src:Entity {{id: $from_id}})
              MATCH (dst:Entity {{id: $to_id}})
              MERGE (src)-[r:{rel_type}]->(dst)
              ON CREATE SET r.created_at = $now
              SET r += $properties
              RETURN r
              """
              with self.driver.session(database=self.database) as session:
                  result = session.run(
                      query,
                      from_id=from_entity_id,
                      to_id=to_entity_id,
                      now=now,
                      properties=props
                  )
                  record = result.single()
                  if record:
                      logger.debug(f"✅ 添加关系: {from_entity_id} -{relationship_type}-> {to_entity_id}")
                      return True
                  return False
          except Exception as e:
              logger.error(f"❌ 添加关系失败: {e}")
              return False

      def find_related_entities(
          self,
          entity_id: str,
          relationship_types: List[str] = None,
          max_depth: int = 2,
          limit: int = 50
      ) -> List[Dict[str, Any]]:
          """Find entities reachable from *entity_id* in either direction.

          Each related entity is returned once, together with the shortest path
          found to it.

          Args:
              entity_id: ID of the starting entity.
              relationship_types: Only traverse these relationship types (None/empty = any).
              max_depth: Maximum number of hops; values below 1 are treated as 1.
              limit: Maximum number of entities returned.

          Returns:
              List[Dict]: Entity properties plus ``distance`` (hops) and
              ``relationship_path`` (list of relationship type names), ordered by
              distance, then name. Empty list on failure (the error is logged).
          """
          try:
              rel_filter = self._relationship_type_filter(relationship_types)
              # int() keeps the interpolated depth a plain number.
              depth = max(1, int(max_depth))
              # Keep only the shortest path per entity. A plain DISTINCT would apply
              # to the whole (entity, distance, path) row and return duplicates.
              query = f"""
              MATCH path = (start:Entity {{id: $entity_id}})-[{rel_filter}*1..{depth}]-(related:Entity)
              WHERE start.id <> related.id
              WITH related, path
              ORDER BY length(path)
              WITH related, collect(path)[0] AS shortest
              RETURN related,
                     length(shortest) AS distance,
                     [rel IN relationships(shortest) | type(rel)] AS relationship_path
              ORDER BY distance, related.name
              LIMIT $limit
              """
              with self.driver.session(database=self.database) as session:
                  result = session.run(query, entity_id=entity_id, limit=limit)
                  entities = []
                  for record in result:
                      entity_data = dict(record["related"])
                      entity_data["distance"] = record["distance"]
                      entity_data["relationship_path"] = record["relationship_path"]
                      entities.append(entity_data)
                  logger.debug(f"🔍 找到 {len(entities)} 个相关实体")
                  return entities
          except Exception as e:
              logger.error(f"❌ 查找相关实体失败: {e}")
              return []

      def search_entities_by_name(self, name_pattern: str, entity_types: List[str] = None, limit: int = 20) -> List[Dict[str, Any]]:
          """Search entities whose name contains *name_pattern* (case-sensitive substring).

          The pattern is matched literally, so characters such as ``(``, ``+`` or
          ``.`` in entity names are safe.

          Args:
              name_pattern: Substring to look for in ``e.name``.
              entity_types: Only return entities with one of these types.
              limit: Maximum number of results.

          Returns:
              List[Dict]: Matching entity property dicts ordered by name; empty list
              on failure (the error is logged).
          """
          try:
              conditions = ["e.name CONTAINS $pattern"]
              params = {"pattern": name_pattern, "limit": limit}
              if entity_types:
                  conditions.append("e.type IN $types")
                  params["types"] = entity_types
              query = f"""
              MATCH (e:Entity)
              WHERE {" AND ".join(conditions)}
              RETURN e
              ORDER BY e.name
              LIMIT $limit
              """
              with self.driver.session(database=self.database) as session:
                  result = session.run(query, **params)
                  entities = [dict(record["e"]) for record in result]
                  logger.debug(f"🔍 按名称搜索到 {len(entities)} 个实体")
                  return entities
          except Exception as e:
              logger.error(f"❌ 按名称搜索实体失败: {e}")
              return []

      def get_entity_relationships(self, entity_id: str) -> List[Dict[str, Any]]:
          """Return every relationship between *entity_id* and another Entity.

          Args:
              entity_id: Entity ID.

          Returns:
              List[Dict]: One dict per relationship with keys ``relationship``
              (its properties), ``other_entity`` (the other node's properties) and
              ``direction`` ("outgoing" or "incoming", relative to *entity_id*).
              Empty list on failure (the error is logged).
          """
          try:
              query = """
              MATCH (e:Entity {id: $entity_id})-[r]-(other:Entity)
              RETURN r, other,
                     CASE WHEN startNode(r).id = $entity_id THEN 'outgoing' ELSE 'incoming' END as direction
              """
              with self.driver.session(database=self.database) as session:
                  result = session.run(query, entity_id=entity_id)
                  return [
                      {
                          "relationship": dict(record["r"]),
                          "other_entity": dict(record["other"]),
                          "direction": record["direction"],
                      }
                      for record in result
                  ]
          except Exception as e:
              logger.error(f"❌ 获取实体关系失败: {e}")
              return []

      def delete_entity(self, entity_id: str) -> bool:
          """Delete an entity node together with all of its relationships.

          Args:
              entity_id: Entity ID.

          Returns:
              bool: True if a node was deleted; False if none matched or on failure.
          """
          try:
              query = """
              MATCH (e:Entity {id: $entity_id})
              DETACH DELETE e
              """
              with self.driver.session(database=self.database) as session:
                  summary = session.run(query, entity_id=entity_id).consume()
                  deleted_count = summary.counters.nodes_deleted
                  logger.info(f"✅ 删除实体: {entity_id} (删除 {deleted_count} 个节点)")
                  return deleted_count > 0
          except Exception as e:
              logger.error(f"❌ 删除实体失败: {e}")
              return False

      def clear_all(self) -> bool:
          """Delete every node and relationship in the database, of any label.

          Returns:
              bool: True on success, False on failure (the error is logged).
          """
          try:
              query = "MATCH (n) DETACH DELETE n"
              with self.driver.session(database=self.database) as session:
                  summary = session.run(query).consume()
                  deleted_nodes = summary.counters.nodes_deleted
                  deleted_relationships = summary.counters.relationships_deleted
                  logger.info(f"✅ 清空Neo4j数据库: 删除 {deleted_nodes} 个节点, {deleted_relationships} 个关系")
                  return True
          except Exception as e:
              logger.error(f"❌ 清空数据库失败: {e}")
              return False

      def get_stats(self) -> Dict[str, Any]:
          """Return node/relationship counts for the database.

          Returns:
              Dict: Keys ``total_nodes``, ``total_relationships``, ``entity_nodes``
              and ``memory_nodes``. Empty dict on failure (the error is logged).
          """
          try:
              queries = {
                  "total_nodes": "MATCH (n) RETURN count(n) as count",
                  "total_relationships": "MATCH ()-[r]->() RETURN count(r) as count",
                  "entity_nodes": "MATCH (n:Entity) RETURN count(n) as count",
                  "memory_nodes": "MATCH (n:Memory) RETURN count(n) as count",
              }
              stats = {}
              with self.driver.session(database=self.database) as session:
                  for key, query in queries.items():
                      record = session.run(query).single()
                      stats[key] = record["count"] if record else 0
              return stats
          except Exception as e:
              logger.error(f"❌ 获取统计信息失败: {e}")
              return {}

      def health_check(self) -> bool:
          """Run a trivial query to check that the server answers.

          Returns:
              bool: True if ``RETURN 1`` round-trips; False otherwise (errors are logged).
          """
          try:
              with self.driver.session(database=self.database) as session:
                  record = session.run("RETURN 1 as health").single()
                  return record is not None and record["health"] == 1
          except Exception as e:
              logger.error(f"❌ Neo4j健康检查失败: {e}")
              return False

      def close(self) -> None:
          """Close the driver and its connection pool. Calling it again is a no-op."""
          driver = getattr(self, "driver", None)
          self.driver = None
          if driver is not None:
              driver.close()

      def _close_quietly(self) -> None:
          """Best-effort close, used on error paths and in ``__del__``; never raises."""
          try:
              self.close()
          except Exception:
              # Deliberately ignored: a failing close must not mask the error being
              # propagated, nor raise from a finalizer.
              pass

      def __enter__(self):
          """Support ``with Neo4jGraphStore(...) as store:``."""
          return self

      def __exit__(self, exc_type, exc_value, traceback):
          """Close the driver on exit; exceptions from the ``with`` body propagate."""
          self.close()

      def __del__(self):
          """Finalizer: release the driver if the owner never called :meth:`close`."""
          self._close_quietly()