conversation.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. """对话管理 — Conversation 与 ConversationManager"""
  2. import uuid
  3. from datetime import datetime
  4. from typing import Optional, Dict, Any, List
  5. from .message import Message
  6. class Conversation:
  7. """单条对话线,管理一条从根到叶的消息链"""
  8. def __init__(
  9. self,
  10. conversation_id: Optional[str] = None,
  11. name: str = "",
  12. system_prompt: Optional[str] = None,
  13. metadata: Optional[Dict[str, Any]] = None,
  14. ):
  15. self.conversation_id: str = conversation_id or uuid.uuid4().hex[:12]
  16. self.name: str = name
  17. self.system_prompt: Optional[str] = system_prompt
  18. self.created_at: datetime = datetime.now()
  19. self.updated_at: datetime = self.created_at
  20. self.messages: List[Message] = []
  21. self.metadata: Dict[str, Any] = metadata or {}
  22. def add_message(self, message: Message) -> Message:
  23. message.conversation_id = self.conversation_id
  24. if self.messages:
  25. message.parent_id = self.messages[-1].message_id
  26. self.messages.append(message)
  27. self.updated_at = datetime.now()
  28. return message
  29. def get_messages(
  30. self, start: Optional[int] = None, end: Optional[int] = None
  31. ) -> List[Message]:
  32. return self.messages[start:end]
  33. def get_last_message(self) -> Optional[Message]:
  34. return self.messages[-1] if self.messages else None
  35. def get_message_by_id(self, message_id: str) -> Optional[Message]:
  36. for m in self.messages:
  37. if m.message_id == message_id:
  38. return m
  39. return None
  40. def fork(self, at_message_id: str, new_name: str = "") -> "Conversation":
  41. target_idx = -1
  42. for i, m in enumerate(self.messages):
  43. if m.message_id == at_message_id:
  44. target_idx = i
  45. break
  46. if target_idx == -1:
  47. raise ValueError(f"消息 {at_message_id} 不存在")
  48. new_conv = Conversation(
  49. name=new_name or f"{self.name} (分支)",
  50. system_prompt=self.system_prompt,
  51. metadata={**self.metadata, "forked_from": self.conversation_id},
  52. )
  53. for i, m in enumerate(self.messages[: target_idx + 1]):
  54. if i == target_idx:
  55. fork_msg = m.model_copy(deep=True)
  56. fork_msg.branch_point = True
  57. fork_msg.conversation_id = new_conv.conversation_id
  58. fork_msg.parent_id = (
  59. new_conv.messages[-1].message_id if new_conv.messages else None
  60. )
  61. new_conv.messages.append(fork_msg)
  62. else:
  63. copied = m.model_copy(deep=True)
  64. copied.conversation_id = new_conv.conversation_id
  65. copied.parent_id = (
  66. new_conv.messages[-1].message_id if new_conv.messages else None
  67. )
  68. new_conv.messages.append(copied)
  69. return new_conv
  70. def to_dict(self) -> Dict[str, Any]:
  71. return {
  72. "conversation_id": self.conversation_id,
  73. "name": self.name,
  74. "system_prompt": self.system_prompt,
  75. "created_at": self.created_at.isoformat(),
  76. "updated_at": self.updated_at.isoformat(),
  77. "messages": [m.to_dict(full=True) for m in self.messages],
  78. "metadata": self.metadata,
  79. }
  80. @classmethod
  81. def from_dict(cls, data: Dict[str, Any]) -> "Conversation":
  82. conv = cls(
  83. conversation_id=data["conversation_id"],
  84. name=data.get("name", ""),
  85. system_prompt=data.get("system_prompt"),
  86. metadata=data.get("metadata", {}),
  87. )
  88. conv.created_at = datetime.fromisoformat(data["created_at"])
  89. conv.updated_at = datetime.fromisoformat(data["updated_at"])
  90. for md in data.get("messages", []):
  91. conv.messages.append(
  92. Message(
  93. content=md["content"],
  94. role=md["role"],
  95. message_id=md.get("message_id", ""),
  96. conversation_id=md.get("conversation_id", conv.conversation_id),
  97. parent_id=md.get("parent_id"),
  98. branch_point=md.get("branch_point", False),
  99. timestamp=datetime.fromisoformat(md["timestamp"])
  100. if md.get("timestamp")
  101. else None,
  102. metadata=md.get("metadata", {}),
  103. )
  104. )
  105. return conv
  106. def to_llm_messages(self) -> List[Dict[str, str]]:
  107. return [{"role": m.role, "content": m.content} for m in self.messages]
  108. def __len__(self) -> int:
  109. return len(self.messages)
  110. def __str__(self) -> str:
  111. return f"Conversation(id={self.conversation_id}, name={self.name}, messages={len(self.messages)})"