|
|
@@ -0,0 +1,551 @@
|
|
|
+"""Orchestrator coordinating the deep research workflow."""
|
|
|
+
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import logging
|
|
|
+import re
|
|
|
+from pathlib import Path
|
|
|
+from queue import Empty, Queue
|
|
|
+from threading import Lock, Thread
|
|
|
+from typing import Any, Callable, Iterator
|
|
|
+
|
|
|
+from hello_agents import HelloAgentsLLM, ToolAwareSimpleAgent
|
|
|
+from hello_agents.tools import ToolRegistry
|
|
|
+from hello_agents.tools.builtin.note_tool import NoteTool
|
|
|
+
|
|
|
+from config import Configuration
|
|
|
+from prompts import (
|
|
|
+ report_writer_instructions,
|
|
|
+ task_summarizer_instructions,
|
|
|
+ todo_planner_system_prompt,
|
|
|
+)
|
|
|
+from models import SummaryState, SummaryStateOutput, TodoItem
|
|
|
+from services.planner import PlanningService
|
|
|
+from services.reporter import ReportingService
|
|
|
+from services.search import dispatch_search, prepare_research_context
|
|
|
+from services.summarizer import SummarizationService
|
|
|
+from services.tool_events import ToolCallTracker
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+class DeepResearchAgent:
|
|
|
+ """Coordinator orchestrating TODO-based research workflow using HelloAgents."""
|
|
|
+
|
|
|
+ def __init__(self, config: Configuration | None = None) -> None:
|
|
|
+ """Initialise the coordinator with configuration and shared tools."""
|
|
|
+ self.config = config or Configuration.from_env()
|
|
|
+ self.llm = self._init_llm()
|
|
|
+
|
|
|
+ self.note_tool = (
|
|
|
+ NoteTool(workspace=self.config.notes_workspace)
|
|
|
+ if self.config.enable_notes
|
|
|
+ else None
|
|
|
+ )
|
|
|
+ self.tools_registry: ToolRegistry | None = None
|
|
|
+ if self.note_tool:
|
|
|
+ registry = ToolRegistry()
|
|
|
+ registry.register_tool(self.note_tool)
|
|
|
+ self.tools_registry = registry
|
|
|
+
|
|
|
+ self._tool_tracker = ToolCallTracker(
|
|
|
+ self.config.notes_workspace if self.config.enable_notes else None
|
|
|
+ )
|
|
|
+ self._tool_event_sink_enabled = False
|
|
|
+ self._state_lock = Lock()
|
|
|
+
|
|
|
+ self.todo_agent = self._create_tool_aware_agent(
|
|
|
+ name="研究规划专家",
|
|
|
+ system_prompt=todo_planner_system_prompt.strip(),
|
|
|
+ )
|
|
|
+ self.report_agent = self._create_tool_aware_agent(
|
|
|
+ name="报告撰写专家",
|
|
|
+ system_prompt=report_writer_instructions.strip(),
|
|
|
+ )
|
|
|
+
|
|
|
+ self._summarizer_factory: Callable[[], ToolAwareSimpleAgent] = lambda: self._create_tool_aware_agent( # noqa: E501
|
|
|
+ name="任务总结专家",
|
|
|
+ system_prompt=task_summarizer_instructions.strip(),
|
|
|
+ )
|
|
|
+
|
|
|
+ self.planner = PlanningService(self.todo_agent, self.config)
|
|
|
+ self.summarizer = SummarizationService(self._summarizer_factory, self.config)
|
|
|
+ self.reporting = ReportingService(self.report_agent, self.config)
|
|
|
+ self._last_search_notices: list[str] = []
|
|
|
+
|
|
|
+ # ------------------------------------------------------------------
|
|
|
+ # Public API
|
|
|
+ # ------------------------------------------------------------------
|
|
|
+ def _init_llm(self) -> HelloAgentsLLM:
|
|
|
+ """Instantiate HelloAgentsLLM following configuration preferences."""
|
|
|
+ llm_kwargs: dict[str, Any] = {"temperature": 0.0}
|
|
|
+
|
|
|
+ model_id = self.config.llm_model_id or self.config.local_llm
|
|
|
+ if model_id:
|
|
|
+ llm_kwargs["model"] = model_id
|
|
|
+
|
|
|
+ provider = (self.config.llm_provider or "").strip()
|
|
|
+ if provider:
|
|
|
+ llm_kwargs["provider"] = provider
|
|
|
+
|
|
|
+ if provider == "ollama":
|
|
|
+ llm_kwargs["base_url"] = self.config.sanitized_ollama_url()
|
|
|
+ if self.config.llm_api_key:
|
|
|
+ llm_kwargs["api_key"] = self.config.llm_api_key
|
|
|
+ else:
|
|
|
+ llm_kwargs["api_key"] = "ollama"
|
|
|
+ elif provider == "lmstudio":
|
|
|
+ llm_kwargs["base_url"] = self.config.lmstudio_base_url
|
|
|
+ if self.config.llm_api_key:
|
|
|
+ llm_kwargs["api_key"] = self.config.llm_api_key
|
|
|
+ else:
|
|
|
+ if self.config.llm_base_url:
|
|
|
+ llm_kwargs["base_url"] = self.config.llm_base_url
|
|
|
+ if self.config.llm_api_key:
|
|
|
+ llm_kwargs["api_key"] = self.config.llm_api_key
|
|
|
+
|
|
|
+ return HelloAgentsLLM(**llm_kwargs)
|
|
|
+
|
|
|
+ def _create_tool_aware_agent(self, *, name: str, system_prompt: str) -> ToolAwareSimpleAgent:
|
|
|
+ """Instantiate a ToolAwareSimpleAgent sharing tool registry and tracker."""
|
|
|
+ return ToolAwareSimpleAgent(
|
|
|
+ name=name,
|
|
|
+ llm=self.llm,
|
|
|
+ system_prompt=system_prompt,
|
|
|
+ enable_tool_calling=self.tools_registry is not None,
|
|
|
+ tool_registry=self.tools_registry,
|
|
|
+ tool_call_listener=self._tool_tracker.record,
|
|
|
+ )
|
|
|
+
|
|
|
+ def _set_tool_event_sink(self, sink: Callable[[dict[str, Any]], None] | None) -> None:
|
|
|
+ """Enable or disable immediate tool event callbacks."""
|
|
|
+ self._tool_event_sink_enabled = sink is not None
|
|
|
+ self._tool_tracker.set_event_sink(sink)
|
|
|
+
|
|
|
+ def run(self, topic: str) -> SummaryStateOutput:
|
|
|
+ """Execute the research workflow and return the final report."""
|
|
|
+ state = SummaryState(research_topic=topic)
|
|
|
+ state.todo_items = self.planner.plan_todo_list(state)
|
|
|
+ self._drain_tool_events(state)
|
|
|
+
|
|
|
+ if not state.todo_items:
|
|
|
+ logger.info("No TODO items generated; falling back to single task")
|
|
|
+ state.todo_items = [self.planner.create_fallback_task(state)]
|
|
|
+
|
|
|
+ for task in state.todo_items:
|
|
|
+ self._execute_task(state, task, emit_stream=False)
|
|
|
+
|
|
|
+ report = self.reporting.generate_report(state)
|
|
|
+ self._drain_tool_events(state)
|
|
|
+ state.structured_report = report
|
|
|
+ state.running_summary = report
|
|
|
+ self._persist_final_report(state, report)
|
|
|
+
|
|
|
+ return SummaryStateOutput(
|
|
|
+ running_summary=report,
|
|
|
+ report_markdown=report,
|
|
|
+ todo_items=state.todo_items,
|
|
|
+ )
|
|
|
+
|
|
|
+ def run_stream(self, topic: str) -> Iterator[dict[str, Any]]:
|
|
|
+ """Execute the workflow yielding incremental progress events."""
|
|
|
+ state = SummaryState(research_topic=topic)
|
|
|
+ logger.debug("Starting streaming research: topic=%s", topic)
|
|
|
+ yield {"type": "status", "message": "初始化研究流程"}
|
|
|
+
|
|
|
+ state.todo_items = self.planner.plan_todo_list(state)
|
|
|
+ for event in self._drain_tool_events(state, step=0):
|
|
|
+ yield event
|
|
|
+ if not state.todo_items:
|
|
|
+ state.todo_items = [self.planner.create_fallback_task(state)]
|
|
|
+
|
|
|
+ channel_map: dict[int, dict[str, Any]] = {}
|
|
|
+ for index, task in enumerate(state.todo_items, start=1):
|
|
|
+ token = f"task_{task.id}"
|
|
|
+ task.stream_token = token
|
|
|
+ channel_map[task.id] = {"step": index, "token": token}
|
|
|
+
|
|
|
+ yield {
|
|
|
+ "type": "todo_list",
|
|
|
+ "tasks": [self._serialize_task(t) for t in state.todo_items],
|
|
|
+ "step": 0,
|
|
|
+ }
|
|
|
+
|
|
|
+ event_queue: Queue[dict[str, Any]] = Queue()
|
|
|
+
|
|
|
+ def enqueue(
|
|
|
+ event: dict[str, Any],
|
|
|
+ *,
|
|
|
+ task: TodoItem | None = None,
|
|
|
+ step_override: int | None = None,
|
|
|
+ ) -> None:
|
|
|
+ payload = dict(event)
|
|
|
+ target_task_id = payload.get("task_id")
|
|
|
+ if task is not None:
|
|
|
+ target_task_id = task.id
|
|
|
+ payload["task_id"] = task.id
|
|
|
+
|
|
|
+ channel = channel_map.get(target_task_id) if target_task_id is not None else None
|
|
|
+ if channel:
|
|
|
+ payload.setdefault("step", channel["step"])
|
|
|
+ payload["stream_token"] = channel["token"]
|
|
|
+ if step_override is not None:
|
|
|
+ payload["step"] = step_override
|
|
|
+ event_queue.put(payload)
|
|
|
+
|
|
|
+ def tool_event_sink(event: dict[str, Any]) -> None:
|
|
|
+ enqueue(event)
|
|
|
+
|
|
|
+ self._set_tool_event_sink(tool_event_sink)
|
|
|
+
|
|
|
+ threads: list[Thread] = []
|
|
|
+
|
|
|
+ def worker(task: TodoItem, step: int) -> None:
|
|
|
+ try:
|
|
|
+ enqueue(
|
|
|
+ {
|
|
|
+ "type": "task_status",
|
|
|
+ "task_id": task.id,
|
|
|
+ "status": "in_progress",
|
|
|
+ "title": task.title,
|
|
|
+ "intent": task.intent,
|
|
|
+ "note_id": task.note_id,
|
|
|
+ "note_path": task.note_path,
|
|
|
+ },
|
|
|
+ task=task,
|
|
|
+ )
|
|
|
+
|
|
|
+ for event in self._execute_task(state, task, emit_stream=True, step=step):
|
|
|
+ enqueue(event, task=task)
|
|
|
+ except Exception as exc: # pragma: no cover - defensive guardrail
|
|
|
+ logger.exception("Task execution failed", exc_info=exc)
|
|
|
+ enqueue(
|
|
|
+ {
|
|
|
+ "type": "task_status",
|
|
|
+ "task_id": task.id,
|
|
|
+ "status": "failed",
|
|
|
+ "detail": str(exc),
|
|
|
+ "title": task.title,
|
|
|
+ "intent": task.intent,
|
|
|
+ "note_id": task.note_id,
|
|
|
+ "note_path": task.note_path,
|
|
|
+ },
|
|
|
+ task=task,
|
|
|
+ )
|
|
|
+ finally:
|
|
|
+ enqueue({"type": "__task_done__", "task_id": task.id})
|
|
|
+
|
|
|
+ for task in state.todo_items:
|
|
|
+ step = channel_map.get(task.id, {}).get("step", 0)
|
|
|
+ thread = Thread(target=worker, args=(task, step), daemon=True)
|
|
|
+ threads.append(thread)
|
|
|
+ thread.start()
|
|
|
+
|
|
|
+ active_workers = len(state.todo_items)
|
|
|
+ finished_workers = 0
|
|
|
+
|
|
|
+ try:
|
|
|
+ while finished_workers < active_workers:
|
|
|
+ event = event_queue.get()
|
|
|
+ if event.get("type") == "__task_done__":
|
|
|
+ finished_workers += 1
|
|
|
+ continue
|
|
|
+ yield event
|
|
|
+
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ event = event_queue.get_nowait()
|
|
|
+ except Empty:
|
|
|
+ break
|
|
|
+ if event.get("type") != "__task_done__":
|
|
|
+ yield event
|
|
|
+ finally:
|
|
|
+ self._set_tool_event_sink(None)
|
|
|
+ for thread in threads:
|
|
|
+ thread.join()
|
|
|
+
|
|
|
+ report = self.reporting.generate_report(state)
|
|
|
+ final_step = len(state.todo_items) + 1
|
|
|
+ for event in self._drain_tool_events(state, step=final_step):
|
|
|
+ yield event
|
|
|
+ state.structured_report = report
|
|
|
+ state.running_summary = report
|
|
|
+
|
|
|
+ note_event = self._persist_final_report(state, report)
|
|
|
+ if note_event:
|
|
|
+ yield note_event
|
|
|
+
|
|
|
+ yield {
|
|
|
+ "type": "final_report",
|
|
|
+ "report": report,
|
|
|
+ "note_id": state.report_note_id,
|
|
|
+ "note_path": state.report_note_path,
|
|
|
+ }
|
|
|
+ yield {"type": "done"}
|
|
|
+
|
|
|
+ # ------------------------------------------------------------------
|
|
|
+ # Execution helpers
|
|
|
+ # ------------------------------------------------------------------
|
|
|
+ def _execute_task(
|
|
|
+ self,
|
|
|
+ state: SummaryState,
|
|
|
+ task: TodoItem,
|
|
|
+ *,
|
|
|
+ emit_stream: bool,
|
|
|
+ step: int | None = None,
|
|
|
+ ) -> Iterator[dict[str, Any]]:
|
|
|
+ """Run search + summarization for a single task."""
|
|
|
+ task.status = "in_progress"
|
|
|
+
|
|
|
+ search_result, notices, answer_text, backend = dispatch_search(
|
|
|
+ task.query,
|
|
|
+ self.config,
|
|
|
+ state.research_loop_count,
|
|
|
+ )
|
|
|
+ self._last_search_notices = notices
|
|
|
+ task.notices = notices
|
|
|
+
|
|
|
+ if emit_stream:
|
|
|
+ for event in self._drain_tool_events(state, step=step):
|
|
|
+ yield event
|
|
|
+ else:
|
|
|
+ self._drain_tool_events(state)
|
|
|
+
|
|
|
+ if notices and emit_stream:
|
|
|
+ for notice in notices:
|
|
|
+ if notice:
|
|
|
+ yield {
|
|
|
+ "type": "status",
|
|
|
+ "message": notice,
|
|
|
+ "task_id": task.id,
|
|
|
+ "step": step,
|
|
|
+ }
|
|
|
+
|
|
|
+ if not search_result or not search_result.get("results"):
|
|
|
+ task.status = "skipped"
|
|
|
+ if emit_stream:
|
|
|
+ for event in self._drain_tool_events(state, step=step):
|
|
|
+ yield event
|
|
|
+ yield {
|
|
|
+ "type": "task_status",
|
|
|
+ "task_id": task.id,
|
|
|
+ "status": "skipped",
|
|
|
+ "title": task.title,
|
|
|
+ "intent": task.intent,
|
|
|
+ "note_id": task.note_id,
|
|
|
+ "note_path": task.note_path,
|
|
|
+ "step": step,
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ self._drain_tool_events(state)
|
|
|
+ return
|
|
|
+ else:
|
|
|
+ if not emit_stream:
|
|
|
+ self._drain_tool_events(state)
|
|
|
+
|
|
|
+ sources_summary, context = prepare_research_context(
|
|
|
+ search_result,
|
|
|
+ answer_text,
|
|
|
+ self.config,
|
|
|
+ )
|
|
|
+
|
|
|
+ task.sources_summary = sources_summary
|
|
|
+
|
|
|
+ with self._state_lock:
|
|
|
+ state.web_research_results.append(context)
|
|
|
+ state.sources_gathered.append(sources_summary)
|
|
|
+ state.research_loop_count += 1
|
|
|
+
|
|
|
+ summary_text: str | None = None
|
|
|
+
|
|
|
+ if emit_stream:
|
|
|
+ for event in self._drain_tool_events(state, step=step):
|
|
|
+ yield event
|
|
|
+ yield {
|
|
|
+ "type": "sources",
|
|
|
+ "task_id": task.id,
|
|
|
+ "latest_sources": sources_summary,
|
|
|
+ "raw_context": context,
|
|
|
+ "step": step,
|
|
|
+ "backend": backend,
|
|
|
+ "note_id": task.note_id,
|
|
|
+ "note_path": task.note_path,
|
|
|
+ }
|
|
|
+
|
|
|
+ summary_stream, summary_getter = self.summarizer.stream_task_summary(state, task, context)
|
|
|
+ try:
|
|
|
+ for event in self._drain_tool_events(state, step=step):
|
|
|
+ yield event
|
|
|
+ for chunk in summary_stream:
|
|
|
+ if chunk:
|
|
|
+ yield {
|
|
|
+ "type": "task_summary_chunk",
|
|
|
+ "task_id": task.id,
|
|
|
+ "content": chunk,
|
|
|
+ "note_id": task.note_id,
|
|
|
+ "step": step,
|
|
|
+ }
|
|
|
+ for event in self._drain_tool_events(state, step=step):
|
|
|
+ yield event
|
|
|
+ finally:
|
|
|
+ summary_text = summary_getter()
|
|
|
+ else:
|
|
|
+ summary_text = self.summarizer.summarize_task(state, task, context)
|
|
|
+ self._drain_tool_events(state)
|
|
|
+
|
|
|
+ task.summary = summary_text.strip() if summary_text else "暂无可用信息"
|
|
|
+ task.status = "completed"
|
|
|
+
|
|
|
+ if emit_stream:
|
|
|
+ for event in self._drain_tool_events(state, step=step):
|
|
|
+ yield event
|
|
|
+ yield {
|
|
|
+ "type": "task_status",
|
|
|
+ "task_id": task.id,
|
|
|
+ "status": "completed",
|
|
|
+ "summary": task.summary,
|
|
|
+ "sources_summary": task.sources_summary,
|
|
|
+ "note_id": task.note_id,
|
|
|
+ "note_path": task.note_path,
|
|
|
+ "step": step,
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ self._drain_tool_events(state)
|
|
|
+
|
|
|
+ def _drain_tool_events(
|
|
|
+ self,
|
|
|
+ state: SummaryState,
|
|
|
+ *,
|
|
|
+ step: int | None = None,
|
|
|
+ ) -> list[dict[str, Any]]:
|
|
|
+ """Proxy to the shared tool call tracker."""
|
|
|
+ events = self._tool_tracker.drain(state, step=step)
|
|
|
+ if self._tool_event_sink_enabled:
|
|
|
+ return []
|
|
|
+ return events
|
|
|
+
|
|
|
+ @property
|
|
|
+ def _tool_call_events(self) -> list[dict[str, Any]]:
|
|
|
+ """Expose recorded tool events for legacy integrations."""
|
|
|
+ return self._tool_tracker.as_dicts()
|
|
|
+
|
|
|
+ def _serialize_task(self, task: TodoItem) -> dict[str, Any]:
|
|
|
+ """Convert task dataclass to serializable dict for frontend."""
|
|
|
+ return {
|
|
|
+ "id": task.id,
|
|
|
+ "title": task.title,
|
|
|
+ "intent": task.intent,
|
|
|
+ "query": task.query,
|
|
|
+ "status": task.status,
|
|
|
+ "summary": task.summary,
|
|
|
+ "sources_summary": task.sources_summary,
|
|
|
+ "note_id": task.note_id,
|
|
|
+ "note_path": task.note_path,
|
|
|
+ "stream_token": task.stream_token,
|
|
|
+ }
|
|
|
+
|
|
|
+ def _persist_final_report(self, state: SummaryState, report: str) -> dict[str, Any] | None:
|
|
|
+ if not self.note_tool or not report or not report.strip():
|
|
|
+ return None
|
|
|
+
|
|
|
+ note_title = f"研究报告:{state.research_topic}".strip() or "研究报告"
|
|
|
+ tags = ["deep_research", "report"]
|
|
|
+ content = report.strip()
|
|
|
+
|
|
|
+ note_id = self._find_existing_report_note_id(state)
|
|
|
+ response = ""
|
|
|
+
|
|
|
+ if note_id:
|
|
|
+ response = self.note_tool.run(
|
|
|
+ {
|
|
|
+ "action": "update",
|
|
|
+ "note_id": note_id,
|
|
|
+ "title": note_title,
|
|
|
+ "note_type": "conclusion",
|
|
|
+ "tags": tags,
|
|
|
+ "content": content,
|
|
|
+ }
|
|
|
+ )
|
|
|
+ if response.startswith("❌"):
|
|
|
+ note_id = None
|
|
|
+
|
|
|
+ if not note_id:
|
|
|
+ response = self.note_tool.run(
|
|
|
+ {
|
|
|
+ "action": "create",
|
|
|
+ "title": note_title,
|
|
|
+ "note_type": "conclusion",
|
|
|
+ "tags": tags,
|
|
|
+ "content": content,
|
|
|
+ }
|
|
|
+ )
|
|
|
+ note_id = self._extract_note_id_from_text(response)
|
|
|
+
|
|
|
+ if not note_id:
|
|
|
+ return None
|
|
|
+
|
|
|
+ state.report_note_id = note_id
|
|
|
+ if self.config.notes_workspace:
|
|
|
+ note_path = Path(self.config.notes_workspace) / f"{note_id}.md"
|
|
|
+ state.report_note_path = str(note_path)
|
|
|
+ else:
|
|
|
+ note_path = None
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "type": "report_note",
|
|
|
+ "note_id": note_id,
|
|
|
+ "title": note_title,
|
|
|
+ "content": content,
|
|
|
+ }
|
|
|
+ if note_path:
|
|
|
+ payload["note_path"] = str(note_path)
|
|
|
+
|
|
|
+ return payload
|
|
|
+
|
|
|
+ def _find_existing_report_note_id(self, state: SummaryState) -> str | None:
|
|
|
+ if state.report_note_id:
|
|
|
+ return state.report_note_id
|
|
|
+
|
|
|
+ for event in reversed(self._tool_tracker.as_dicts()):
|
|
|
+ if event.get("tool") != "note":
|
|
|
+ continue
|
|
|
+
|
|
|
+ parameters = event.get("parsed_parameters") or {}
|
|
|
+ if not isinstance(parameters, dict):
|
|
|
+ continue
|
|
|
+
|
|
|
+ action = parameters.get("action")
|
|
|
+ if action not in {"create", "update"}:
|
|
|
+ continue
|
|
|
+
|
|
|
+ note_type = parameters.get("note_type")
|
|
|
+ if note_type != "conclusion":
|
|
|
+ title = parameters.get("title")
|
|
|
+ if not (isinstance(title, str) and title.startswith("研究报告")):
|
|
|
+ continue
|
|
|
+
|
|
|
+ note_id = parameters.get("note_id")
|
|
|
+ if not note_id:
|
|
|
+ note_id = self._tool_tracker._extract_note_id(event.get("result", "")) # type: ignore[attr-defined]
|
|
|
+
|
|
|
+ if note_id:
|
|
|
+ return note_id
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _extract_note_id_from_text(response: str) -> str | None:
|
|
|
+ if not response:
|
|
|
+ return None
|
|
|
+
|
|
|
+ match = re.search(r"ID:\s*([^\n]+)", response)
|
|
|
+ if not match:
|
|
|
+ return None
|
|
|
+
|
|
|
+ return match.group(1).strip()
|
|
|
+
|
|
|
+
|
|
|
+def run_deep_research(topic: str, config: Configuration | None = None) -> SummaryStateOutput:
|
|
|
+ """Convenience function mirroring the class-based API."""
|
|
|
+ agent = DeepResearchAgent(config=config)
|
|
|
+ return agent.run(topic)
|