mermaid_agent_service.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
import asyncio
import logging
from typing import Any, AsyncGenerator, Dict

from app.config import settings
from app.agents.mermaid.code_utils import apply_direction, prune_complexity
from app.agents.mermaid.pipeline import MermaidPipeline
from app.tools.mermaid_validator_tool import MermaidValidatorTool
  7. class MermaidAgentService:
  8. def __init__(self):
  9. self.validator = MermaidValidatorTool()
  10. self.pipeline = MermaidPipeline(self.validator)
  11. async def stream_chat(self, mode: str, prompt: str, direction: str = "TD") -> AsyncGenerator[Dict[str, Any], None]:
  12. yield {"type": "status", "phase": "start", "message": "开始生成"}
  13. yield {"type": "status", "phase": "generating", "message": "模型生成中"}
  14. try:
  15. llm_timeout = max(30, int(settings.llm_timeout))
  16. standard_timeout = llm_timeout * 2 + 20
  17. single_timeout = llm_timeout + 20
  18. validate_timeout = 20
  19. optimized_text = ""
  20. generated_from_optimized = False
  21. if mode == "standard":
  22. yield {"type": "status", "phase": "optimizing", "message": "文本优化中"}
  23. standard_result = await asyncio.wait_for(
  24. asyncio.to_thread(self.pipeline.generate_standard, prompt), timeout=standard_timeout
  25. )
  26. optimized_text = standard_result.get("optimized_text", "")
  27. raw_code = standard_result.get("mermaid_code", "")
  28. generated_from_optimized = bool(standard_result.get("generated_from_optimized", False))
  29. yield {"type": "status", "phase": "creating", "message": "基于优化文本生成流程图"}
  30. else:
  31. raw_code = await asyncio.wait_for(
  32. asyncio.to_thread(self.pipeline.generate_once, mode, prompt), timeout=single_timeout
  33. )
  34. extracted = prune_complexity(raw_code, mode)
  35. extracted = apply_direction(extracted, direction)
  36. yield {"type": "status", "phase": "validating", "message": "语法校验中"}
  37. validation = await asyncio.wait_for(
  38. asyncio.to_thread(self.pipeline.post_validate, extracted), timeout=validate_timeout
  39. )
  40. yield {
  41. "type": "result",
  42. "mode": mode,
  43. "valid": validation["valid"],
  44. "attempts": validation["attempts"],
  45. "mermaid_code": validation["mermaid_code"],
  46. "optimized_text": optimized_text,
  47. "generated_from_optimized": generated_from_optimized,
  48. "message": validation.get("message", ""),
  49. }
  50. yield {"type": "done"}
  51. except asyncio.TimeoutError:
  52. yield {
  53. "type": "error",
  54. "phase": "timeout",
  55. "message": "生成超时,请简化输入后重试。",
  56. }
  57. yield {"type": "done"}
  58. except Exception as exc:
  59. yield {
  60. "type": "error",
  61. "phase": "exception",
  62. "message": f"生成失败: {str(exc)}",
  63. }
  64. yield {"type": "done"}