agent.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. """Agent基类"""
  2. from abc import ABC, abstractmethod
  3. from typing import Optional, Iterator
  4. from .message import Message
  5. from .llm import HelloAgentsLLM
  6. from .config import Config
  7. from .stream import StreamEvent
  8. from .conversation_manager import ConversationManager
  9. class Agent(ABC):
  10. """Agent基类"""
  11. def __init__(
  12. self,
  13. name: str,
  14. llm: HelloAgentsLLM,
  15. system_prompt: Optional[str] = None,
  16. config: Optional[Config] = None,
  17. conversation_manager: Optional[ConversationManager] = None,
  18. ):
  19. self.name = name
  20. self.llm = llm
  21. self.system_prompt = system_prompt
  22. self.config = config or Config()
  23. self._history: list[Message] = []
  24. self.conversation_manager = conversation_manager
  25. def _resolve_history(self, conversation_id: Optional[str] = None) -> list[Message]:
  26. if self.conversation_manager and conversation_id:
  27. conv = self.conversation_manager.get_conversation(conversation_id)
  28. if conv:
  29. return conv.messages
  30. return self._history
  31. def _save_conversation_messages(
  32. self, input_text: str, response: str, conversation_id: Optional[str] = None
  33. ) -> None:
  34. if self.conversation_manager and conversation_id:
  35. self.conversation_manager.add_message(conversation_id, input_text, "user")
  36. self.conversation_manager.add_message(
  37. conversation_id, response, "assistant"
  38. )
  39. else:
  40. self.add_message(Message(input_text, "user"))
  41. self.add_message(Message(response, "assistant"))
  42. @abstractmethod
  43. def run(self, input_text: str, **kwargs) -> str:
  44. pass
  45. def stream_run(self, input_text: str, **kwargs) -> Iterator[StreamEvent]:
  46. result = self.run(input_text, **kwargs)
  47. yield StreamEvent.text(result)
  48. yield StreamEvent.done(result)
  49. def add_message(self, message: Message):
  50. self._history.append(message)
  51. def clear_history(self):
  52. self._history.clear()
  53. def get_history(self) -> list[Message]:
  54. return self._history.copy()
  55. def __str__(self) -> str:
  56. return f"Agent(name={self.name}, provider={self.llm.provider})"
  57. def __repr__(self) -> str:
  58. return self.__str__()