# Dialogue_System.py
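"""
Smart search assistant - a real search system built on LangGraph + the Tavily API.

1. Understand the user's request
2. Search for real information through the Tavily API
3. Generate an answer grounded in the search results
"""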
  7. import asyncio
  8. from typing import TypedDict, Annotated
  9. from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
  10. from langchain_openai import ChatOpenAI
  11. from langgraph.graph import StateGraph, START, END
  12. from langgraph.graph.message import add_messages
  13. from langgraph.checkpoint.memory import InMemorySaver
  14. import os
  15. from dotenv import load_dotenv
  16. from tavily import TavilyClient
# Load environment variables (the os.getenv() calls below read their keys
# after this runs).
load_dotenv()
# State structure shared by every node of the search workflow.
class SearchState(TypedDict):
    """Graph state passed between the understand / search / answer nodes."""

    # Conversation history; annotated with the add_messages reducer.
    messages: Annotated[list, add_messages]
    user_query: str  # the user's query
    search_query: str  # optimized query sent to the search API
    search_results: str  # formatted Tavily search results
    final_answer: str  # final answer text
    step: str  # current workflow step
# Initialize the chat model and the Tavily client.
# Model id and base URL fall back to the defaults given here when the
# corresponding environment variables are unset.
# NOTE(review): both clients are constructed at import time. If TavilyClient
# rejects a missing key in its constructor (TODO confirm against the SDK),
# the friendly TAVILY_API_KEY check in main() can never be reached.
llm = ChatOpenAI(
    model=os.getenv("LLM_MODEL_ID", "gpt-4o-mini"),
    api_key=os.getenv("LLM_API_KEY"),
    base_url=os.getenv("LLM_BASE_URL", "https://api.openai.com/v1"),
    temperature=0.7
)
# Initialize the Tavily client.
tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
  36. def understand_query_node(state: SearchState) -> SearchState:
  37. """步骤1:理解用户查询并生成搜索关键词"""
  38. # 获取最新的用户消息
  39. user_message = ""
  40. for msg in reversed(state["messages"]):
  41. if isinstance(msg, HumanMessage):
  42. user_message = msg.content
  43. break
  44. understand_prompt = f"""分析用户的查询:"{user_message}"
  45. 请完成两个任务:
  46. 1. 简洁总结用户想要了解什么
  47. 2. 生成最适合搜索的关键词(中英文均可,要精准)
  48. 格式:
  49. 理解:[用户需求总结]
  50. 搜索词:[最佳搜索关键词]"""
  51. response = llm.invoke([SystemMessage(content=understand_prompt)])
  52. # 提取搜索关键词
  53. response_text = response.content
  54. search_query = user_message # 默认使用原始查询
  55. if "搜索词:" in response_text:
  56. search_query = response_text.split("搜索词:")[1].strip()
  57. elif "搜索关键词:" in response_text:
  58. search_query = response_text.split("搜索关键词:")[1].strip()
  59. return {
  60. "user_query": response.content,
  61. "search_query": search_query,
  62. "step": "understood",
  63. "messages": [AIMessage(content=f"我理解您的需求:{response.content}")]
  64. }
  65. def tavily_search_node(state: SearchState) -> SearchState:
  66. """步骤2:使用Tavily API进行真实搜索"""
  67. search_query = state["search_query"]
  68. try:
  69. print(f"🔍 正在搜索: {search_query}")
  70. # 调用Tavily搜索API
  71. response = tavily_client.search(
  72. query=search_query,
  73. search_depth="basic",
  74. include_answer=True,
  75. include_raw_content=False,
  76. max_results=5
  77. )
  78. # 处理搜索结果
  79. search_results = ""
  80. # 优先使用Tavily的综合答案
  81. if response.get("answer"):
  82. search_results = f"综合答案:\n{response['answer']}\n\n"
  83. # 添加具体的搜索结果
  84. if response.get("results"):
  85. search_results += "相关信息:\n"
  86. for i, result in enumerate(response["results"][:3], 1):
  87. title = result.get("title", "")
  88. content = result.get("content", "")
  89. url = result.get("url", "")
  90. search_results += f"{i}. {title}\n{content}\n来源:{url}\n\n"
  91. if not search_results:
  92. search_results = "抱歉,没有找到相关信息。"
  93. return {
  94. "search_results": search_results,
  95. "step": "searched",
  96. "messages": [AIMessage(content=f"✅ 搜索完成!找到了相关信息,正在为您整理答案...")]
  97. }
  98. except Exception as e:
  99. error_msg = f"搜索时发生错误: {str(e)}"
  100. print(f"❌ {error_msg}")
  101. return {
  102. "search_results": f"搜索失败:{error_msg}",
  103. "step": "search_failed",
  104. "messages": [AIMessage(content="❌ 搜索遇到问题,我将基于已有知识为您回答")]
  105. }
  106. def generate_answer_node(state: SearchState) -> SearchState:
  107. """步骤3:基于搜索结果生成最终答案"""
  108. # 检查是否有搜索结果
  109. if state["step"] == "search_failed":
  110. # 如果搜索失败,基于LLM知识回答
  111. fallback_prompt = f"""搜索API暂时不可用,请基于您的知识回答用户的问题:
  112. 用户问题:{state['user_query']}
  113. 请提供一个有用的回答,并说明这是基于已有知识的回答。"""
  114. response = llm.invoke([SystemMessage(content=fallback_prompt)])
  115. return {
  116. "final_answer": response.content,
  117. "step": "completed",
  118. "messages": [AIMessage(content=response.content)]
  119. }
  120. # 基于搜索结果生成答案
  121. answer_prompt = f"""基于以下搜索结果为用户提供完整、准确的答案:
  122. 用户问题:{state['user_query']}
  123. 搜索结果:
  124. {state['search_results']}
  125. 请要求:
  126. 1. 综合搜索结果,提供准确、有用的回答
  127. 2. 如果是技术问题,提供具体的解决方案或代码
  128. 3. 引用重要信息的来源
  129. 4. 回答要结构清晰、易于理解
  130. 5. 如果搜索结果不够完整,请说明并提供补充建议"""
  131. response = llm.invoke([SystemMessage(content=answer_prompt)])
  132. return {
  133. "final_answer": response.content,
  134. "step": "completed",
  135. "messages": [AIMessage(content=response.content)]
  136. }
  137. # 构建搜索工作流
  138. def create_search_assistant():
  139. workflow = StateGraph(SearchState)
  140. # 添加三个节点
  141. workflow.add_node("understand", understand_query_node)
  142. workflow.add_node("search", tavily_search_node)
  143. workflow.add_node("answer", generate_answer_node)
  144. # 设置线性流程
  145. workflow.add_edge(START, "understand")
  146. workflow.add_edge("understand", "search")
  147. workflow.add_edge("search", "answer")
  148. workflow.add_edge("answer", END)
  149. # 编译图
  150. memory = InMemorySaver()
  151. app = workflow.compile(checkpointer=memory)
  152. return app
  153. async def main():
  154. """主函数:运行智能搜索助手"""
  155. # 检查API密钥
  156. if not os.getenv("TAVILY_API_KEY"):
  157. print("❌ 错误:请在.env文件中配置TAVILY_API_KEY")
  158. return
  159. app = create_search_assistant()
  160. print("🔍 智能搜索助手启动!")
  161. print("我会使用Tavily API为您搜索最新、最准确的信息")
  162. print("支持各种问题:新闻、技术、知识问答等")
  163. print("(输入 'quit' 退出)\n")
  164. session_count = 0
  165. while True:
  166. user_input = input("🤔 您想了解什么: ").strip()
  167. if user_input.lower() in ['quit', 'q', '退出', 'exit']:
  168. print("感谢使用!再见!👋")
  169. break
  170. if not user_input:
  171. continue
  172. session_count += 1
  173. config = {"configurable": {"thread_id": f"search-session-{session_count}"}}
  174. # 初始状态
  175. initial_state = {
  176. "messages": [HumanMessage(content=user_input)],
  177. "user_query": "",
  178. "search_query": "",
  179. "search_results": "",
  180. "final_answer": "",
  181. "step": "start"
  182. }
  183. try:
  184. print("\n" + "="*60)
  185. # 执行工作流
  186. async for output in app.astream(initial_state, config=config):
  187. for node_name, node_output in output.items():
  188. if "messages" in node_output and node_output["messages"]:
  189. latest_message = node_output["messages"][-1]
  190. if isinstance(latest_message, AIMessage):
  191. if node_name == "understand":
  192. print(f"🧠 理解阶段: {latest_message.content}")
  193. elif node_name == "search":
  194. print(f"🔍 搜索阶段: {latest_message.content}")
  195. elif node_name == "answer":
  196. print(f"\n💡 最终回答:\n{latest_message.content}")
  197. print("\n" + "="*60 + "\n")
  198. except Exception as e:
  199. print(f"❌ 发生错误: {e}")
  200. print("请重新输入您的问题。\n")
# Start the interactive assistant only when run as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())