agent.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. """Orchestrator coordinating the deep research workflow."""
  2. from __future__ import annotations
  3. import logging
  4. import re
  5. from pathlib import Path
  6. from queue import Empty, Queue
  7. from threading import Lock, Thread
  8. from typing import Any, Callable, Iterator
  9. from hello_agents import HelloAgentsLLM, ToolAwareSimpleAgent
  10. from hello_agents.tools import ToolRegistry
  11. from hello_agents.tools.builtin.note_tool import NoteTool
  12. from config import Configuration
  13. from prompts import (
  14. report_writer_instructions,
  15. task_summarizer_instructions,
  16. todo_planner_system_prompt,
  17. )
  18. from models import SummaryState, SummaryStateOutput, TodoItem
  19. from services.planner import PlanningService
  20. from services.reporter import ReportingService
  21. from services.search import dispatch_search, prepare_research_context
  22. from services.summarizer import SummarizationService
  23. from services.tool_events import ToolCallTracker
  24. logger = logging.getLogger(__name__)
  25. class DeepResearchAgent:
  26. """Coordinator orchestrating TODO-based research workflow using HelloAgents."""
  27. def __init__(self, config: Configuration | None = None) -> None:
  28. """Initialise the coordinator with configuration and shared tools."""
  29. self.config = config or Configuration.from_env()
  30. self.llm = self._init_llm()
  31. self.note_tool = (
  32. NoteTool(workspace=self.config.notes_workspace)
  33. if self.config.enable_notes
  34. else None
  35. )
  36. self.tools_registry: ToolRegistry | None = None
  37. if self.note_tool:
  38. registry = ToolRegistry()
  39. registry.register_tool(self.note_tool)
  40. self.tools_registry = registry
  41. self._tool_tracker = ToolCallTracker(
  42. self.config.notes_workspace if self.config.enable_notes else None
  43. )
  44. self._tool_event_sink_enabled = False
  45. self._state_lock = Lock()
  46. self.todo_agent = self._create_tool_aware_agent(
  47. name="研究规划专家",
  48. system_prompt=todo_planner_system_prompt.strip(),
  49. )
  50. self.report_agent = self._create_tool_aware_agent(
  51. name="报告撰写专家",
  52. system_prompt=report_writer_instructions.strip(),
  53. )
  54. self._summarizer_factory: Callable[[], ToolAwareSimpleAgent] = lambda: self._create_tool_aware_agent( # noqa: E501
  55. name="任务总结专家",
  56. system_prompt=task_summarizer_instructions.strip(),
  57. )
  58. self.planner = PlanningService(self.todo_agent, self.config)
  59. self.summarizer = SummarizationService(self._summarizer_factory, self.config)
  60. self.reporting = ReportingService(self.report_agent, self.config)
  61. self._last_search_notices: list[str] = []
  62. # ------------------------------------------------------------------
  63. # Public API
  64. # ------------------------------------------------------------------
  65. def _init_llm(self) -> HelloAgentsLLM:
  66. """Instantiate HelloAgentsLLM following configuration preferences."""
  67. llm_kwargs: dict[str, Any] = {"temperature": 0.0}
  68. model_id = self.config.llm_model_id or self.config.local_llm
  69. if model_id:
  70. llm_kwargs["model"] = model_id
  71. provider = (self.config.llm_provider or "").strip()
  72. if provider:
  73. llm_kwargs["provider"] = provider
  74. if provider == "ollama":
  75. llm_kwargs["base_url"] = self.config.sanitized_ollama_url()
  76. if self.config.llm_api_key:
  77. llm_kwargs["api_key"] = self.config.llm_api_key
  78. else:
  79. llm_kwargs["api_key"] = "ollama"
  80. elif provider == "lmstudio":
  81. llm_kwargs["base_url"] = self.config.lmstudio_base_url
  82. if self.config.llm_api_key:
  83. llm_kwargs["api_key"] = self.config.llm_api_key
  84. else:
  85. if self.config.llm_base_url:
  86. llm_kwargs["base_url"] = self.config.llm_base_url
  87. if self.config.llm_api_key:
  88. llm_kwargs["api_key"] = self.config.llm_api_key
  89. return HelloAgentsLLM(**llm_kwargs)
  90. def _create_tool_aware_agent(self, *, name: str, system_prompt: str) -> ToolAwareSimpleAgent:
  91. """Instantiate a ToolAwareSimpleAgent sharing tool registry and tracker."""
  92. return ToolAwareSimpleAgent(
  93. name=name,
  94. llm=self.llm,
  95. system_prompt=system_prompt,
  96. enable_tool_calling=self.tools_registry is not None,
  97. tool_registry=self.tools_registry,
  98. tool_call_listener=self._tool_tracker.record,
  99. )
  100. def _set_tool_event_sink(self, sink: Callable[[dict[str, Any]], None] | None) -> None:
  101. """Enable or disable immediate tool event callbacks."""
  102. self._tool_event_sink_enabled = sink is not None
  103. self._tool_tracker.set_event_sink(sink)
  104. def run(self, topic: str) -> SummaryStateOutput:
  105. """Execute the research workflow and return the final report."""
  106. state = SummaryState(research_topic=topic)
  107. state.todo_items = self.planner.plan_todo_list(state)
  108. self._drain_tool_events(state)
  109. if not state.todo_items:
  110. logger.info("No TODO items generated; falling back to single task")
  111. state.todo_items = [self.planner.create_fallback_task(state)]
  112. for task in state.todo_items:
  113. self._execute_task(state, task, emit_stream=False)
  114. report = self.reporting.generate_report(state)
  115. self._drain_tool_events(state)
  116. state.structured_report = report
  117. state.running_summary = report
  118. self._persist_final_report(state, report)
  119. return SummaryStateOutput(
  120. running_summary=report,
  121. report_markdown=report,
  122. todo_items=state.todo_items,
  123. )
  124. def run_stream(self, topic: str) -> Iterator[dict[str, Any]]:
  125. """Execute the workflow yielding incremental progress events."""
  126. state = SummaryState(research_topic=topic)
  127. logger.debug("Starting streaming research: topic=%s", topic)
  128. yield {"type": "status", "message": "初始化研究流程"}
  129. state.todo_items = self.planner.plan_todo_list(state)
  130. for event in self._drain_tool_events(state, step=0):
  131. yield event
  132. if not state.todo_items:
  133. state.todo_items = [self.planner.create_fallback_task(state)]
  134. channel_map: dict[int, dict[str, Any]] = {}
  135. for index, task in enumerate(state.todo_items, start=1):
  136. token = f"task_{task.id}"
  137. task.stream_token = token
  138. channel_map[task.id] = {"step": index, "token": token}
  139. yield {
  140. "type": "todo_list",
  141. "tasks": [self._serialize_task(t) for t in state.todo_items],
  142. "step": 0,
  143. }
  144. event_queue: Queue[dict[str, Any]] = Queue()
  145. def enqueue(
  146. event: dict[str, Any],
  147. *,
  148. task: TodoItem | None = None,
  149. step_override: int | None = None,
  150. ) -> None:
  151. payload = dict(event)
  152. target_task_id = payload.get("task_id")
  153. if task is not None:
  154. target_task_id = task.id
  155. payload["task_id"] = task.id
  156. channel = channel_map.get(target_task_id) if target_task_id is not None else None
  157. if channel:
  158. payload.setdefault("step", channel["step"])
  159. payload["stream_token"] = channel["token"]
  160. if step_override is not None:
  161. payload["step"] = step_override
  162. event_queue.put(payload)
  163. def tool_event_sink(event: dict[str, Any]) -> None:
  164. enqueue(event)
  165. self._set_tool_event_sink(tool_event_sink)
  166. threads: list[Thread] = []
  167. def worker(task: TodoItem, step: int) -> None:
  168. try:
  169. enqueue(
  170. {
  171. "type": "task_status",
  172. "task_id": task.id,
  173. "status": "in_progress",
  174. "title": task.title,
  175. "intent": task.intent,
  176. "note_id": task.note_id,
  177. "note_path": task.note_path,
  178. },
  179. task=task,
  180. )
  181. for event in self._execute_task(state, task, emit_stream=True, step=step):
  182. enqueue(event, task=task)
  183. except Exception as exc: # pragma: no cover - defensive guardrail
  184. logger.exception("Task execution failed", exc_info=exc)
  185. enqueue(
  186. {
  187. "type": "task_status",
  188. "task_id": task.id,
  189. "status": "failed",
  190. "detail": str(exc),
  191. "title": task.title,
  192. "intent": task.intent,
  193. "note_id": task.note_id,
  194. "note_path": task.note_path,
  195. },
  196. task=task,
  197. )
  198. finally:
  199. enqueue({"type": "__task_done__", "task_id": task.id})
  200. for task in state.todo_items:
  201. step = channel_map.get(task.id, {}).get("step", 0)
  202. thread = Thread(target=worker, args=(task, step), daemon=True)
  203. threads.append(thread)
  204. thread.start()
  205. active_workers = len(state.todo_items)
  206. finished_workers = 0
  207. try:
  208. while finished_workers < active_workers:
  209. event = event_queue.get()
  210. if event.get("type") == "__task_done__":
  211. finished_workers += 1
  212. continue
  213. yield event
  214. while True:
  215. try:
  216. event = event_queue.get_nowait()
  217. except Empty:
  218. break
  219. if event.get("type") != "__task_done__":
  220. yield event
  221. finally:
  222. self._set_tool_event_sink(None)
  223. for thread in threads:
  224. thread.join()
  225. report = self.reporting.generate_report(state)
  226. final_step = len(state.todo_items) + 1
  227. for event in self._drain_tool_events(state, step=final_step):
  228. yield event
  229. state.structured_report = report
  230. state.running_summary = report
  231. note_event = self._persist_final_report(state, report)
  232. if note_event:
  233. yield note_event
  234. yield {
  235. "type": "final_report",
  236. "report": report,
  237. "note_id": state.report_note_id,
  238. "note_path": state.report_note_path,
  239. }
  240. yield {"type": "done"}
  241. # ------------------------------------------------------------------
  242. # Execution helpers
  243. # ------------------------------------------------------------------
  244. def _execute_task(
  245. self,
  246. state: SummaryState,
  247. task: TodoItem,
  248. *,
  249. emit_stream: bool,
  250. step: int | None = None,
  251. ) -> Iterator[dict[str, Any]]:
  252. """Run search + summarization for a single task."""
  253. task.status = "in_progress"
  254. search_result, notices, answer_text, backend = dispatch_search(
  255. task.query,
  256. self.config,
  257. state.research_loop_count,
  258. )
  259. self._last_search_notices = notices
  260. task.notices = notices
  261. if emit_stream:
  262. for event in self._drain_tool_events(state, step=step):
  263. yield event
  264. else:
  265. self._drain_tool_events(state)
  266. if notices and emit_stream:
  267. for notice in notices:
  268. if notice:
  269. yield {
  270. "type": "status",
  271. "message": notice,
  272. "task_id": task.id,
  273. "step": step,
  274. }
  275. if not search_result or not search_result.get("results"):
  276. task.status = "skipped"
  277. if emit_stream:
  278. for event in self._drain_tool_events(state, step=step):
  279. yield event
  280. yield {
  281. "type": "task_status",
  282. "task_id": task.id,
  283. "status": "skipped",
  284. "title": task.title,
  285. "intent": task.intent,
  286. "note_id": task.note_id,
  287. "note_path": task.note_path,
  288. "step": step,
  289. }
  290. else:
  291. self._drain_tool_events(state)
  292. return
  293. else:
  294. if not emit_stream:
  295. self._drain_tool_events(state)
  296. sources_summary, context = prepare_research_context(
  297. search_result,
  298. answer_text,
  299. self.config,
  300. )
  301. task.sources_summary = sources_summary
  302. with self._state_lock:
  303. state.web_research_results.append(context)
  304. state.sources_gathered.append(sources_summary)
  305. state.research_loop_count += 1
  306. summary_text: str | None = None
  307. if emit_stream:
  308. for event in self._drain_tool_events(state, step=step):
  309. yield event
  310. yield {
  311. "type": "sources",
  312. "task_id": task.id,
  313. "latest_sources": sources_summary,
  314. "raw_context": context,
  315. "step": step,
  316. "backend": backend,
  317. "note_id": task.note_id,
  318. "note_path": task.note_path,
  319. }
  320. summary_stream, summary_getter = self.summarizer.stream_task_summary(state, task, context)
  321. try:
  322. for event in self._drain_tool_events(state, step=step):
  323. yield event
  324. for chunk in summary_stream:
  325. if chunk:
  326. yield {
  327. "type": "task_summary_chunk",
  328. "task_id": task.id,
  329. "content": chunk,
  330. "note_id": task.note_id,
  331. "step": step,
  332. }
  333. for event in self._drain_tool_events(state, step=step):
  334. yield event
  335. finally:
  336. summary_text = summary_getter()
  337. else:
  338. summary_text = self.summarizer.summarize_task(state, task, context)
  339. self._drain_tool_events(state)
  340. task.summary = summary_text.strip() if summary_text else "暂无可用信息"
  341. task.status = "completed"
  342. if emit_stream:
  343. for event in self._drain_tool_events(state, step=step):
  344. yield event
  345. yield {
  346. "type": "task_status",
  347. "task_id": task.id,
  348. "status": "completed",
  349. "summary": task.summary,
  350. "sources_summary": task.sources_summary,
  351. "note_id": task.note_id,
  352. "note_path": task.note_path,
  353. "step": step,
  354. }
  355. else:
  356. self._drain_tool_events(state)
  357. def _drain_tool_events(
  358. self,
  359. state: SummaryState,
  360. *,
  361. step: int | None = None,
  362. ) -> list[dict[str, Any]]:
  363. """Proxy to the shared tool call tracker."""
  364. events = self._tool_tracker.drain(state, step=step)
  365. if self._tool_event_sink_enabled:
  366. return []
  367. return events
  368. @property
  369. def _tool_call_events(self) -> list[dict[str, Any]]:
  370. """Expose recorded tool events for legacy integrations."""
  371. return self._tool_tracker.as_dicts()
  372. def _serialize_task(self, task: TodoItem) -> dict[str, Any]:
  373. """Convert task dataclass to serializable dict for frontend."""
  374. return {
  375. "id": task.id,
  376. "title": task.title,
  377. "intent": task.intent,
  378. "query": task.query,
  379. "status": task.status,
  380. "summary": task.summary,
  381. "sources_summary": task.sources_summary,
  382. "note_id": task.note_id,
  383. "note_path": task.note_path,
  384. "stream_token": task.stream_token,
  385. }
  386. def _persist_final_report(self, state: SummaryState, report: str) -> dict[str, Any] | None:
  387. if not self.note_tool or not report or not report.strip():
  388. return None
  389. note_title = f"研究报告:{state.research_topic}".strip() or "研究报告"
  390. tags = ["deep_research", "report"]
  391. content = report.strip()
  392. note_id = self._find_existing_report_note_id(state)
  393. response = ""
  394. if note_id:
  395. response = self.note_tool.run(
  396. {
  397. "action": "update",
  398. "note_id": note_id,
  399. "title": note_title,
  400. "note_type": "conclusion",
  401. "tags": tags,
  402. "content": content,
  403. }
  404. )
  405. if response.startswith("❌"):
  406. note_id = None
  407. if not note_id:
  408. response = self.note_tool.run(
  409. {
  410. "action": "create",
  411. "title": note_title,
  412. "note_type": "conclusion",
  413. "tags": tags,
  414. "content": content,
  415. }
  416. )
  417. note_id = self._extract_note_id_from_text(response)
  418. if not note_id:
  419. return None
  420. state.report_note_id = note_id
  421. if self.config.notes_workspace:
  422. note_path = Path(self.config.notes_workspace) / f"{note_id}.md"
  423. state.report_note_path = str(note_path)
  424. else:
  425. note_path = None
  426. payload = {
  427. "type": "report_note",
  428. "note_id": note_id,
  429. "title": note_title,
  430. "content": content,
  431. }
  432. if note_path:
  433. payload["note_path"] = str(note_path)
  434. return payload
  435. def _find_existing_report_note_id(self, state: SummaryState) -> str | None:
  436. if state.report_note_id:
  437. return state.report_note_id
  438. for event in reversed(self._tool_tracker.as_dicts()):
  439. if event.get("tool") != "note":
  440. continue
  441. parameters = event.get("parsed_parameters") or {}
  442. if not isinstance(parameters, dict):
  443. continue
  444. action = parameters.get("action")
  445. if action not in {"create", "update"}:
  446. continue
  447. note_type = parameters.get("note_type")
  448. if note_type != "conclusion":
  449. title = parameters.get("title")
  450. if not (isinstance(title, str) and title.startswith("研究报告")):
  451. continue
  452. note_id = parameters.get("note_id")
  453. if not note_id:
  454. note_id = self._tool_tracker._extract_note_id(event.get("result", "")) # type: ignore[attr-defined]
  455. if note_id:
  456. return note_id
  457. return None
  458. @staticmethod
  459. def _extract_note_id_from_text(response: str) -> str | None:
  460. if not response:
  461. return None
  462. match = re.search(r"ID:\s*([^\n]+)", response)
  463. if not match:
  464. return None
  465. return match.group(1).strip()
  466. def run_deep_research(topic: str, config: Configuration | None = None) -> SummaryStateOutput:
  467. """Convenience function mirroring the class-based API."""
  468. agent = DeepResearchAgent(config=config)
  469. return agent.run(topic)