"""
Smart search assistant - a real search system built on LangGraph and the Tavily API.

1. Understand the user's request
2. Search for real information through the Tavily API
3. Generate an answer based on the search results
"""
import asyncio
import os
import re
from typing import TypedDict, Annotated

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from tavily import TavilyClient
# Load environment variables (via python-dotenv) before any os.getenv lookup below.
load_dotenv()
# State structure shared by the workflow nodes.
class SearchState(TypedDict):
    """State dictionary passed between the nodes of the search workflow."""

    # Message list, annotated with `add_messages`.
    messages: Annotated[list, add_messages]
    # User query.
    user_query: str
    # Optimized search query.
    search_query: str
    # Tavily search results.
    search_results: str
    # Final answer.
    final_answer: str
    # Current step.
    step: str
# Chat model. The model id, API key and base URL are read from the environment;
# the model id and base URL fall back to the defaults given here.
llm = ChatOpenAI(
    model=os.environ.get("LLM_MODEL_ID", "gpt-4o-mini"),
    api_key=os.environ.get("LLM_API_KEY"),
    base_url=os.environ.get("LLM_BASE_URL", "https://api.openai.com/v1"),
    temperature=0.7,
)

# Tavily client, keyed by TAVILY_API_KEY from the environment.
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API_KEY"))
def _extract_search_query(response_text: str, fallback: str) -> str:
    """Return the search keywords found in the model reply, or *fallback*.

    The keywords are whatever follows the first "搜索词" / "搜索关键词" label.
    The label may end in either an ASCII or a full-width colon: the prompt
    template below uses the ASCII form, so a reply that copies the template
    must be recognised too.

    Args:
        response_text: Raw text returned by the model.
        fallback: Value returned when no label is present or nothing usable
            follows it.
    """
    label = re.search(r"搜索(?:关键)?词\s*[::]", response_text)
    if label is None:
        return fallback

    keywords = response_text[label.end():].strip()
    # The template shows the value inside [ ], so the model may echo the brackets.
    if keywords.startswith("[") and keywords.endswith("]"):
        keywords = keywords[1:-1].strip()

    # An empty query is useless to the search step; use the fallback instead.
    return keywords or fallback


def understand_query_node(state: SearchState) -> SearchState:
    """Step 1: interpret the user's request and derive search keywords.

    Takes the content of the most recent HumanMessage in ``state["messages"]``
    (an empty string when there is none), asks the model to summarise the
    request and propose search keywords, then parses the keywords out of the
    reply.

    Returns:
        A state update with:
        - ``user_query``: the model's full reply text (not the raw user input);
        - ``search_query``: the parsed keywords, or the user's message when the
          reply cannot be parsed;
        - ``step``: ``"understood"``;
        - ``messages``: one AIMessage echoing the model's reply.
    """
    # Latest user message, searching from the end of the history.
    user_message = next(
        (
            msg.content
            for msg in reversed(state["messages"])
            if isinstance(msg, HumanMessage)
        ),
        "",
    )

    understand_prompt = (
        f'分析用户的查询:"{user_message}"\n'
        "请完成两个任务:\n"
        "1. 简洁总结用户想要了解什么\n"
        "2. 生成最适合搜索的关键词(中英文均可,要精准)\n"
        "格式:\n"
        "理解:[用户需求总结]\n"
        "搜索词:[最佳搜索关键词]"
    )
    response = llm.invoke([SystemMessage(content=understand_prompt)])
    analysis = response.content

    return {
        "user_query": analysis,
        "search_query": _extract_search_query(analysis, user_message),
        "step": "understood",
        "messages": [AIMessage(content=f"我理解您的需求:{analysis}")],
    }
def _format_search_results(response) -> str:
    """Render a Tavily search response as the text block handed to the model.

    Returns an empty string when the response has neither an ``answer`` nor
    any ``results``.
    """
    parts = []

    # Tavily's synthesized answer goes first when present.
    answer = response.get("answer")
    if answer:
        parts.append(f"综合答案:\n{answer}\n\n")

    results = response.get("results")
    if results:
        parts.append("相关信息:\n")
        # Only the first three hits are rendered, although five are requested.
        for i, result in enumerate(results[:3], 1):
            title = result.get("title", "")
            content = result.get("content", "")
            url = result.get("url", "")
            parts.append(f"{i}. {title}\n{content}\n来源:{url}\n\n")

    return "".join(parts)


def tavily_search_node(state: SearchState) -> SearchState:
    """Step 2: run the search query through the Tavily API.

    Reads ``state["search_query"]``, calls ``tavily_client.search`` and turns
    the response into a text block.

    Returns:
        On success, a state update with ``search_results`` (the formatted text,
        or an apology string when nothing was found), ``step`` set to
        ``"searched"`` and a status AIMessage. On any exception, a state update
        with ``step`` set to ``"search_failed"``, the error text in
        ``search_results`` and an AIMessage announcing the fallback.
    """
    search_query = state["search_query"]

    try:
        print(f"🔍 正在搜索: {search_query}")

        response = tavily_client.search(
            query=search_query,
            search_depth="basic",
            include_answer=True,
            include_raw_content=False,
            max_results=5,
        )

        search_results = _format_search_results(response)

        if search_results:
            status = "✅ 搜索完成!找到了相关信息,正在为您整理答案..."
        else:
            # Do not claim that information was found when the search was empty.
            search_results = "抱歉,没有找到相关信息。"
            status = "⚠️ 搜索完成,但没有找到相关信息,正在为您整理答案..."

        return {
            "search_results": search_results,
            "step": "searched",
            "messages": [AIMessage(content=status)],
        }

    except Exception as e:
        # Deliberately broad: any failure (network, API, malformed response)
        # switches the workflow to the knowledge-only fallback instead of
        # aborting the run.
        error_msg = f"搜索时发生错误: {str(e)}"
        print(f"❌ {error_msg}")

        return {
            "search_results": f"搜索失败:{error_msg}",
            "step": "search_failed",
            "messages": [AIMessage(content="❌ 搜索遇到问题,我将基于已有知识为您回答")],
        }
def generate_answer_node(state: SearchState) -> SearchState:
    """Step 3: produce the final answer.

    When ``state["step"]`` is ``"search_failed"`` the model is asked to answer
    from its own knowledge; otherwise it is asked to answer from
    ``state["search_results"]``. Either way the reply is returned as
    ``final_answer`` and as an AIMessage, with ``step`` set to ``"completed"``.
    """
    search_failed = state["step"] == "search_failed"
    question = state["user_query"]

    if search_failed:
        # Search was unavailable: answer from the model's own knowledge.
        prompt = (
            "搜索API暂时不可用,请基于您的知识回答用户的问题:\n"
            f"用户问题:{question}\n"
            "请提供一个有用的回答,并说明这是基于已有知识的回答。"
        )
    else:
        # Normal path: answer from the collected search results.
        findings = state["search_results"]
        prompt = (
            "基于以下搜索结果为用户提供完整、准确的答案:\n"
            f"用户问题:{question}\n"
            "搜索结果:\n"
            f"{findings}\n"
            "请要求:\n"
            "1. 综合搜索结果,提供准确、有用的回答\n"
            "2. 如果是技术问题,提供具体的解决方案或代码\n"
            "3. 引用重要信息的来源\n"
            "4. 回答要结构清晰、易于理解\n"
            "5. 如果搜索结果不够完整,请说明并提供补充建议"
        )

    reply = llm.invoke([SystemMessage(content=prompt)]).content

    return {
        "final_answer": reply,
        "step": "completed",
        "messages": [AIMessage(content=reply)],
    }
# Build the search workflow.
def create_search_assistant():
    """Assemble and compile the linear understand -> search -> answer graph.

    Returns:
        The compiled workflow, using an InMemorySaver as its checkpointer.
    """
    graph = StateGraph(SearchState)

    # The three nodes, in execution order.
    stages = (
        ("understand", understand_query_node),
        ("search", tavily_search_node),
        ("answer", generate_answer_node),
    )
    for stage_name, handler in stages:
        graph.add_node(stage_name, handler)

    # Linear flow: START -> understand -> search -> answer -> END.
    route = [START, *(stage_name for stage_name, _ in stages), END]
    for source, target in zip(route, route[1:]):
        graph.add_edge(source, target)

    # Compile the graph.
    return graph.compile(checkpointer=InMemorySaver())
async def main():
    """Run the interactive search assistant.

    Refuses to start when TAVILY_API_KEY is not set. Otherwise it reads
    questions from standard input until the user types one of the quit words
    or input ends, runs every question through the workflow under a fresh
    thread id, and prints each node's latest AI message as it arrives. An
    error during one run is reported and the loop continues.
    """
    # Check the API key.
    if not os.getenv("TAVILY_API_KEY"):
        print("❌ 错误:请在.env文件中配置TAVILY_API_KEY")
        return

    app = create_search_assistant()

    print("🔍 智能搜索助手启动!")
    print("我会使用Tavily API为您搜索最新、最准确的信息")
    print("支持各种问题:新闻、技术、知识问答等")
    print("(输入 'quit' 退出)\n")

    # Text printed in front of each node's message.
    stage_prefixes = {
        "understand": "🧠 理解阶段: ",
        "search": "🔍 搜索阶段: ",
        "answer": "\n💡 最终回答:\n",
    }
    session_count = 0

    while True:
        try:
            user_input = input("🤔 您想了解什么: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupt at the prompt ends the session
            # cleanly instead of escaping as a traceback.
            print("\n感谢使用!再见!👋")
            break

        if user_input.lower() in ['quit', 'q', '退出', 'exit']:
            print("感谢使用!再见!👋")
            break

        if not user_input:
            continue

        session_count += 1
        config = {"configurable": {"thread_id": f"search-session-{session_count}"}}

        # Initial state.
        initial_state = {
            "messages": [HumanMessage(content=user_input)],
            "user_query": "",
            "search_query": "",
            "search_results": "",
            "final_answer": "",
            "step": "start",
        }

        try:
            print("\n" + "=" * 60)

            # Run the workflow, printing each node's update as it is produced.
            async for output in app.astream(initial_state, config=config):
                for node_name, node_output in output.items():
                    if "messages" not in node_output or not node_output["messages"]:
                        continue
                    latest_message = node_output["messages"][-1]
                    if not isinstance(latest_message, AIMessage):
                        continue
                    prefix = stage_prefixes.get(node_name)
                    if prefix is not None:
                        print(f"{prefix}{latest_message.content}")

            print("\n" + "=" * 60 + "\n")

        except Exception as e:
            # Top-level boundary: report the failure and keep the session alive.
            print(f"❌ 发生错误: {e}")
            print("请重新输入您的问题。\n")
# Script entry point: start the interactive assistant when run directly.
if __name__ == "__main__":
    asyncio.run(main())