tool_events.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. """Utility for collecting and exposing tool call events."""
  2. from __future__ import annotations
  3. import logging
  4. import re
  5. from dataclasses import dataclass
  6. from pathlib import Path
  7. from threading import Lock
  8. from typing import Any, Callable, Optional
  9. from models import SummaryState, TodoItem
  10. logger = logging.getLogger(__name__)
  11. @dataclass
  12. class ToolCallEvent:
  13. """Internal representation of a tool call event."""
  14. id: int
  15. agent: str
  16. tool: str
  17. raw_parameters: str
  18. parsed_parameters: dict[str, Any]
  19. result: str
  20. task_id: Optional[int]
  21. note_id: Optional[str]
  22. class ToolCallTracker:
  23. """Collects tool call events and converts them to SSE payloads."""
  24. def __init__(self, notes_workspace: Optional[str]) -> None:
  25. self._notes_workspace = notes_workspace
  26. self._events: list[ToolCallEvent] = []
  27. self._cursor = 0
  28. self._lock = Lock()
  29. self._event_sink: Optional[Callable[[dict[str, Any]], None]] = None
  30. def record(self, payload: dict[str, Any]) -> None:
  31. """记录模型工具调用情况,便于日志与前端展示。"""
  32. agent_name = str(payload.get("agent_name") or "unknown")
  33. tool_name = str(payload.get("tool_name") or "unknown")
  34. raw_parameters = str(payload.get("raw_parameters") or "")
  35. parsed_parameters = payload.get("parsed_parameters") or {}
  36. result_text = str(payload.get("result") or "")
  37. if not isinstance(parsed_parameters, dict):
  38. parsed_parameters = {}
  39. task_id = self._infer_task_id(parsed_parameters)
  40. note_id: Optional[str] = None
  41. if tool_name == "note":
  42. note_id = parsed_parameters.get("note_id")
  43. if note_id is None:
  44. note_id = self._extract_note_id(result_text)
  45. event = ToolCallEvent(
  46. id=len(self._events) + 1,
  47. agent=agent_name,
  48. tool=tool_name,
  49. raw_parameters=raw_parameters,
  50. parsed_parameters=parsed_parameters,
  51. result=result_text,
  52. task_id=task_id,
  53. note_id=note_id,
  54. )
  55. with self._lock:
  56. self._events.append(event)
  57. logger.info(
  58. "Tool call recorded: agent=%s tool=%s task_id=%s note_id=%s parsed_parameters=%s",
  59. agent_name,
  60. tool_name,
  61. task_id,
  62. note_id,
  63. parsed_parameters,
  64. )
  65. sink = self._event_sink
  66. if sink:
  67. sink(self._build_payload(event, step=None))
  68. # ------------------------------------------------------------------
  69. # Draining helpers
  70. # ------------------------------------------------------------------
  71. def drain(self, state: SummaryState, *, step: Optional[int] = None) -> list[dict[str, Any]]:
  72. """提取尚未消费的工具调用事件,并同步任务的 note_id。"""
  73. with self._lock:
  74. if self._cursor >= len(self._events):
  75. return []
  76. new_events = self._events[self._cursor :]
  77. self._cursor = len(self._events)
  78. if state.todo_items:
  79. for event in new_events:
  80. task_id = event.task_id
  81. note_id = event.note_id
  82. if task_id is None or not note_id:
  83. continue
  84. self._attach_note_to_task(state.todo_items, task_id, note_id)
  85. payloads: list[dict[str, Any]] = []
  86. for event in new_events:
  87. payload = self._build_payload(event, step=step)
  88. payloads.append(payload)
  89. return payloads
  90. def reset(self) -> None:
  91. """Clear recorded events."""
  92. with self._lock:
  93. self._events.clear()
  94. self._cursor = 0
  95. def as_dicts(self) -> list[dict[str, Any]]:
  96. """Expose a snapshot of raw events for backwards compatibility."""
  97. with self._lock:
  98. return [
  99. {
  100. "id": event.id,
  101. "agent": event.agent,
  102. "tool": event.tool,
  103. "raw_parameters": event.raw_parameters,
  104. "parsed_parameters": event.parsed_parameters,
  105. "result": event.result,
  106. "task_id": event.task_id,
  107. "note_id": event.note_id,
  108. }
  109. for event in self._events
  110. ]
  111. def set_event_sink(self, sink: Optional[Callable[[dict[str, Any]], None]]) -> None:
  112. """Register a callback for immediate tool event notifications."""
  113. self._event_sink = sink
  114. def _build_payload(self, event: ToolCallEvent, step: Optional[int]) -> dict[str, Any]:
  115. payload = {
  116. "type": "tool_call",
  117. "event_id": event.id,
  118. "agent": event.agent,
  119. "tool": event.tool,
  120. "parameters": event.parsed_parameters,
  121. "result": event.result,
  122. "task_id": event.task_id,
  123. "note_id": event.note_id,
  124. }
  125. if event.note_id and self._notes_workspace:
  126. note_path = Path(self._notes_workspace) / f"{event.note_id}.md"
  127. payload["note_path"] = str(note_path)
  128. if step is not None:
  129. payload["step"] = step
  130. return payload
  131. # ------------------------------------------------------------------
  132. # Internal helpers
  133. # ------------------------------------------------------------------
  134. def _attach_note_to_task(self, tasks: list[TodoItem], task_id: int, note_id: str) -> None:
  135. """Update matching TODO item with note metadata."""
  136. for task in tasks:
  137. if task.id != task_id:
  138. continue
  139. if task.note_id != note_id:
  140. task.note_id = note_id
  141. if self._notes_workspace:
  142. task.note_path = str(Path(self._notes_workspace) / f"{note_id}.md")
  143. elif task.note_path is None and self._notes_workspace:
  144. task.note_path = str(Path(self._notes_workspace) / f"{note_id}.md")
  145. break
  146. def _infer_task_id(self, parameters: dict[str, Any]) -> Optional[int]:
  147. """尝试从工具参数推断 task_id。"""
  148. if not parameters:
  149. return None
  150. if "task_id" in parameters:
  151. try:
  152. return int(parameters["task_id"])
  153. except (TypeError, ValueError):
  154. pass
  155. tags = parameters.get("tags")
  156. if isinstance(tags, list):
  157. for tag in tags:
  158. match = re.search(r"task_(\d+)", str(tag))
  159. if match:
  160. return int(match.group(1))
  161. title = parameters.get("title")
  162. if isinstance(title, str):
  163. match = re.search(r"任务\s*(\d+)", title)
  164. if match:
  165. return int(match.group(1))
  166. return None
  167. def _extract_note_id(self, response: str) -> Optional[str]:
  168. if not response:
  169. return None
  170. match = re.search(r"ID:\s*([^\n]+)", response)
  171. if match:
  172. return match.group(1).strip()
  173. return None