|
|
@@ -3,7 +3,6 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import logging
|
|
|
-import re
|
|
|
from collections.abc import Callable, Iterator
|
|
|
from pathlib import Path
|
|
|
from queue import Empty, Queue
|
|
|
@@ -18,7 +17,6 @@ from config import Configuration
|
|
|
from models import SummaryState, SummaryStateOutput, TodoItem
|
|
|
from prompts import (
|
|
|
report_writer_instructions,
|
|
|
- script_writer_instructions,
|
|
|
task_summarizer_instructions,
|
|
|
todo_planner_system_prompt,
|
|
|
)
|
|
|
@@ -77,11 +75,6 @@ class DeepResearchAgent:
|
|
|
system_prompt=report_writer_instructions.strip(),
|
|
|
llm=self.smart_llm,
|
|
|
)
|
|
|
- self.script_agent = self._create_tool_aware_agent(
|
|
|
- name="脚本策划专家",
|
|
|
- system_prompt=script_writer_instructions.strip(),
|
|
|
- llm=self.default_llm,
|
|
|
- )
|
|
|
|
|
|
self._summarizer_factory: Callable[[], ToolAwareSimpleAgent] = lambda: self._create_tool_aware_agent( # noqa: E501
|
|
|
name="任务总结专家",
|
|
|
@@ -92,22 +85,16 @@ class DeepResearchAgent:
|
|
|
self.planner = PlanningService(self.todo_agent, self.config)
|
|
|
self.summarizer = SummarizationService(self._summarizer_factory, self.config)
|
|
|
self.reporting = ReportingService(self.report_agent, self.config)
|
|
|
- self.script_generator = ScriptGenerationService(self.script_agent, self.config)
|
|
|
+ self.script_generator = ScriptGenerationService(self.config)
|
|
|
self.audio_generator = AudioGenerationService(self.config)
|
|
|
|
|
|
self.podcast_synthesizer = PodcastSynthesisService(self.config)
|
|
|
- self._last_search_notices: list[str] = []
|
|
|
|
|
|
def cancel(self) -> None:
|
|
|
"""请求取消当前正在执行的研究任务。"""
|
|
|
logger.info("Cancel requested for research agent")
|
|
|
self._cancel_event.set()
|
|
|
|
|
|
- def _check_cancelled(self) -> None:
|
|
|
- """检查是否收到取消请求,如果是则抛出 CancelledException。"""
|
|
|
- if self._cancel_event.is_set():
|
|
|
- raise CancelledException("研究任务已被用户取消")
|
|
|
-
|
|
|
def is_cancelled(self) -> bool:
|
|
|
"""检查当前任务是否已被取消。"""
|
|
|
return self._cancel_event.is_set()
|
|
|
@@ -456,14 +443,17 @@ class DeepResearchAgent:
|
|
|
"total": total,
|
|
|
"role": role,
|
|
|
"preview": preview,
|
|
|
- "message": f"[TTS {current}/{total}] 正在为 {role} 生成语音: {preview}",
|
|
|
+ "message": f"[TTS {current}/{total}] ✓ {role} 语音生成成功",
|
|
|
})
|
|
|
return True # 返回 True 表示继续
|
|
|
|
|
|
def run_audio_generation():
|
|
|
"""在单独线程中运行音频生成"""
|
|
|
try:
|
|
|
- files = self.audio_generator.generate_audio(script, task_id, audio_progress_callback)
|
|
|
+ files = self.audio_generator.generate_audio(
|
|
|
+ script, task_id, audio_progress_callback,
|
|
|
+ cancel_event=self._cancel_event,
|
|
|
+ )
|
|
|
audio_result.append(files)
|
|
|
except Exception as e:
|
|
|
if not self.is_cancelled():
|
|
|
@@ -501,7 +491,7 @@ class DeepResearchAgent:
|
|
|
if event.get("type") == "audio_progress":
|
|
|
yield {
|
|
|
"type": "log",
|
|
|
- "message": f"[TTS {event['current']}/{event['total']}] ✓ {event['role']} 语音生成成功"
|
|
|
+ "message": f"[TTS {event['current']}/{event['total']}] ✓ {event['role']} 语音已完成"
|
|
|
}
|
|
|
except Empty:
|
|
|
continue
|
|
|
@@ -532,8 +522,14 @@ class DeepResearchAgent:
|
|
|
"stage": "synthesis",
|
|
|
"message": "正在合成完整播客音频文件...",
|
|
|
}
|
|
|
+
|
|
|
+ # 检查取消
|
|
|
+ if self.is_cancelled():
|
|
|
+ yield {"type": "cancelled", "message": "研究任务已取消"}
|
|
|
+ return
|
|
|
+
|
|
|
yield {"type": "log", "message": "使用 FFmpeg 拼接所有语音片段..."}
|
|
|
- podcast_file = self.podcast_synthesizer.synthesize_podcast(audio_files, task_id)
|
|
|
+ podcast_file = self.podcast_synthesizer.synthesize_podcast(audio_files, task_id, cancel_check=self.is_cancelled)
|
|
|
if podcast_file:
|
|
|
yield {
|
|
|
"type": "podcast_ready",
|
|
|
@@ -575,7 +571,6 @@ class DeepResearchAgent:
|
|
|
self.config,
|
|
|
state.research_loop_count,
|
|
|
)
|
|
|
- self._last_search_notices = notices
|
|
|
task.notices = notices
|
|
|
|
|
|
if emit_stream:
|
|
|
@@ -699,11 +694,6 @@ class DeepResearchAgent:
|
|
|
return []
|
|
|
return events
|
|
|
|
|
|
- @property
|
|
|
- def _tool_call_events(self) -> list[dict[str, Any]]:
|
|
|
- """为旧版集成暴露记录的工具事件。"""
|
|
|
- return self._tool_tracker.as_dicts()
|
|
|
-
|
|
|
def _serialize_task(self, task: TodoItem) -> dict[str, Any]:
|
|
|
"""将任务数据类转换为前端可序列化的字典。"""
|
|
|
return {
|
|
|
@@ -754,7 +744,7 @@ class DeepResearchAgent:
|
|
|
"content": content,
|
|
|
}
|
|
|
)
|
|
|
- note_id = self._extract_note_id_from_text(response)
|
|
|
+ note_id = self._tool_tracker._extract_note_id(response)
|
|
|
|
|
|
if not note_id:
|
|
|
return None
|
|
|
@@ -820,19 +810,4 @@ class DeepResearchAgent:
|
|
|
|
|
|
return None
|
|
|
|
|
|
- @staticmethod
|
|
|
- def _extract_note_id_from_text(response: str) -> str | None:
|
|
|
- if not response:
|
|
|
- return None
|
|
|
-
|
|
|
- match = re.search(r"ID:\s*([^\n]+)", response)
|
|
|
- if not match:
|
|
|
- return None
|
|
|
-
|
|
|
- return match.group(1).strip()
|
|
|
-
|
|
|
|
|
|
-def run_deep_research(topic: str, config: Configuration | None = None) -> SummaryStateOutput:
|
|
|
- """镜像基于类的 API 的便捷函数。"""
|
|
|
- agent = DeepResearchAgent(config=config)
|
|
|
- return agent.run(topic)
|