pipeline.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from typing import Any, Dict
  2. from app.config import settings
  3. from app.prompts.standard_prompt import STANDARD_PROMPT
  4. from app.services.llm_service import LLMService
  5. from app.tools.mermaid_validator_tool import MermaidValidatorTool
  6. from .agent_factory import build_agent
  7. from .code_utils import extract_mermaid, extract_optimized_text
  8. class MermaidPipeline:
  9. def __init__(self, validator: MermaidValidatorTool):
  10. self.validator = validator
  11. def generate_once(self, mode: str, prompt: str) -> str:
  12. agent = build_agent(mode, self.validator)
  13. agent.clear_history()
  14. result_text = agent.run(prompt)
  15. return extract_mermaid(result_text)
  16. def generate_standard(self, prompt: str) -> Dict[str, str]:
  17. optimize_messages = [
  18. {
  19. "role": "system",
  20. "content": (
  21. STANDARD_PROMPT
  22. + "\n补充要求:你现在只执行【第一步:文本优化】。"
  23. "只输出优化后的完整文本,不要输出标题,不要输出Mermaid代码,不要代码块。"
  24. ),
  25. },
  26. {"role": "user", "content": prompt},
  27. ]
  28. optimized_resp = LLMService.create_llm().invoke(optimize_messages)
  29. optimized_text = extract_optimized_text(optimized_resp.content)
  30. source_text = optimized_text.strip() or (optimized_resp.content or "").strip()
  31. code_agent = build_agent("standard-code", self.validator)
  32. code_agent.clear_history()
  33. raw_code_text = code_agent.run(source_text)
  34. return {
  35. "optimized_text": source_text,
  36. "mermaid_code": extract_mermaid(raw_code_text),
  37. "generated_from_optimized": bool(source_text),
  38. }
  39. def repair_once(self, bad_code: str, reason: str) -> str:
  40. messages = [
  41. {"role": "system", "content": "你是 Mermaid 修复器,只输出修复后的 Mermaid 代码。"},
  42. {
  43. "role": "user",
  44. "content": (
  45. "请修复以下 Mermaid 代码并确保可渲染。"
  46. f"\n错误信息: {reason}\n\n代码:\n{bad_code}"
  47. ),
  48. },
  49. ]
  50. response = LLMService.create_llm().invoke(messages)
  51. return extract_mermaid(response.content)
  52. def post_validate(self, code: str) -> Dict[str, Any]:
  53. current = code
  54. attempts = 0
  55. repair_limit = min(settings.validator_max_retries, 1)
  56. while attempts <= repair_limit:
  57. attempts += 1
  58. result = self.validator.run({"code": current})
  59. valid = bool(result.data.get("valid"))
  60. fixed_code = result.data.get("fixed_code", current)
  61. if valid:
  62. return {
  63. "valid": True,
  64. "attempts": attempts,
  65. "mermaid_code": fixed_code,
  66. "message": "validated",
  67. }
  68. if attempts > repair_limit:
  69. return {
  70. "valid": False,
  71. "attempts": attempts,
  72. "mermaid_code": fixed_code,
  73. "message": "; ".join(result.data.get("errors", [])) or "validate failed",
  74. }
  75. reason = "; ".join(result.data.get("errors", [])) or "unknown syntax issue"
  76. current = self.repair_once(fixed_code, reason)
  77. return {
  78. "valid": False,
  79. "attempts": attempts,
  80. "mermaid_code": current,
  81. "message": "validate failed",
  82. }