runner.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from __future__ import annotations
  2. from threading import Thread
  3. from time import perf_counter
  4. from backend.agents.registry import AgentRegistry
  5. from backend.events import event_logger
  6. from backend.models import AgentRequest, TaskRecord, TaskStatus
  7. from backend.tasks.manager import TaskManager
  class TaskRunner:
      """Run tasks stored in a ``TaskManager`` with agents from an ``AgentRegistry``.

      ``run`` executes a task in the calling thread; ``start_background`` hands
      it to a daemon thread. Both paths go through ``_run_now``, which stores
      the outcome via the manager and emits ``task_started`` followed by
      ``task_completed`` or ``task_failed``.
      """

      def __init__(self, registry: AgentRegistry, manager: TaskManager) -> None:
          self.registry = registry
          self.manager = manager

      def run(self, task_id: str) -> TaskRecord:
          """Execute ``task_id`` synchronously and return its final record."""
          return self._run_now(task_id)

      def start_background(self, task_id: str) -> TaskRecord:
          """Start ``task_id`` on a daemon thread and return the running record.

          A task whose status is already ``TaskStatus.running`` is returned
          as-is and no thread is started.
          """
          task = self.manager.get(task_id)
          if task.status == TaskStatus.running:
              return task
          # NOTE(review): this get/check/update sequence is not atomic, so two
          # concurrent calls can both start a thread for the same task. An
          # atomic "claim" on the manager side would close the gap -- confirm
          # whether that matters for current callers.
          task = self.manager.update_status(task_id, TaskStatus.running)
          # Daemon thread: it does not keep the interpreter alive at exit, and
          # any exception escaping _run_now is never returned to this caller.
          thread = Thread(target=self._run_now, args=(task_id,), daemon=True)
          thread.start()
          return task

      def _run_now(self, task_id: str) -> TaskRecord:
          """Mark ``task_id`` running, execute its agent, and store the result.

          An exception raised while resolving the agent, running it, or storing
          its result is recorded via ``manager.fail`` plus a ``task_failed``
          event, and the failed record is returned rather than re-raised.

          The ``task_completed`` event is emitted only after the completion has
          been stored and outside the failure handler. A logging error at that
          point therefore cannot turn an already-completed task into a failed
          one; it propagates to the caller instead.

          Returns:
              The record returned by ``manager.complete`` or ``manager.fail``.
          """
          # For start_background this repeats the status set there; presumably
          # update_status is idempotent -- verify against TaskManager.
          task = self.manager.update_status(task_id, TaskStatus.running)
          event_logger.emit("task_started", agent_id=task.agent_id, task_id=task_id)
          started = perf_counter()
          try:
              agent = self.registry.get(task.agent_id)
              response = agent.run(
                  AgentRequest(input=task.input, context=task.metadata, task_id=task_id)
              )
              elapsed = round(perf_counter() - started, 3)
              # Copy so the agent's own artifacts mapping is not mutated; the
              # runner's timing overrides any "elapsed_seconds" the agent set.
              artifacts = dict(response.artifacts)
              artifacts["elapsed_seconds"] = elapsed
              task = self.manager.complete(
                  task_id, output=response.output, artifacts=artifacts
              )
          except Exception as exc:
              elapsed = round(perf_counter() - started, 3)
              # Some exceptions (e.g. a bare KeyError()) stringify to "";
              # fall back to the class name so the recorded error is never blank.
              error = str(exc) or type(exc).__name__
              task = self.manager.fail(task_id, error)
              event_logger.emit(
                  "task_failed",
                  agent_id=task.agent_id,
                  task_id=task_id,
                  payload={"error": error, "elapsed_seconds": elapsed},
              )
              return task
          event_logger.emit(
              "task_completed",
              agent_id=task.agent_id,
              task_id=task_id,
              payload={"elapsed_seconds": elapsed},
          )
          return task