# mermaid_validator_tool.py (3.0 KB)

  1. import re
  2. from typing import Any, Dict, List
  3. from hello_agents.tools.base import Tool, ToolParameter
  4. from hello_agents.tools.response import ToolResponse
  5. from hello_agents.tools.errors import ToolErrorCode
  # Diagram-type keywords recognized as the declaration (first diagram line) of
  # a Mermaid document. Each entry is compared against the start of that line,
  # so "stateDiagram" also covers "stateDiagram-v2" and "classDiagram" covers
  # "classDiagram-v2". Beta diagram types are listed with their "-beta" suffix
  # so that short words such as "block" do not match arbitrary node ids.
  MERMAID_PREFIXES = (
      "flowchart",
      "graph",
      "sequenceDiagram",
      "classDiagram",
      "stateDiagram",
      "erDiagram",
      "journey",
      "gantt",
      "pie",
      "gitGraph",
      "mindmap",
      "timeline",
      "quadrantChart",
      "requirementDiagram",
      "C4Context",
      "C4Container",
      "C4Component",
      "C4Dynamic",
      "C4Deployment",
      "zenuml",
      "sankey-beta",
      "xychart-beta",
      "block-beta",
      "packet-beta",
      "kanban",
      "architecture-beta",
  )
  class MermaidValidatorTool(Tool):
      """Validate Mermaid diagram code and apply light-weight automatic fixes.

      The tool takes a single ``code`` parameter and normalizes it:
      - strips Markdown code fences;
      - turns ``→`` into ``-->``;
      - drops blank lines;
      - inserts ``flowchart TD`` when no diagram-type declaration is found.

      It then runs heuristic structural checks (declaration present, bracket
      counts balanced, flowcharts contain at least one node or link). The code
      is never rendered, so VALID means "passed these heuristics" only.
      """

      # Leading diagram keyword, e.g. "flowchart", "stateDiagram-v2", "gitGraph".
      _KEYWORD_RE = re.compile(r"[A-Za-z0-9]+(?:-[A-Za-z0-9]+)*")
      # A link operator (-->, ---, -.->, ==>, ~~~, <-->, ...) or a shaped node
      # (id[...], id(...), id{...}, id>...]) somewhere in a flowchart line.
      _FLOW_ELEMENT_RE = re.compile(r"\w+\s*(?:--|==|-\.|~~~|<-|<=)|\w+[\[({>]")
      # Double-quoted label text; brackets inside quotes are literal characters.
      _QUOTED_RE = re.compile(r'"[^"]*"')
      # Bracket pairs whose opening and closing counts must agree.
      _BRACKET_PAIRS = (("(", ")"), ("[", "]"), ("{", "}"))

      def __init__(self):
          """Register the tool under its public name and description."""
          super().__init__(
              name="MermaidValidatorTool",
              description="校验并修复 Mermaid 代码,返回可渲染代码",
          )

      def get_parameters(self) -> List[ToolParameter]:
          """Declare the single required string parameter ``code``."""
          return [
              ToolParameter(
                  name="code",
                  type="string",
                  description="待校验的 Mermaid 代码",
                  required=True,
              )
          ]

      def run(self, parameters: Dict[str, Any]) -> ToolResponse:
          """Normalize and validate ``parameters["code"]``.

          Returns:
              - an INVALID_PARAM error response when ``code`` is missing, None,
                or blank;
              - a success response whose text starts with ``VALID`` when no
                structural errors are found;
              - otherwise a partial response whose text starts with ``INVALID``
                and lists the errors.
              In the success and partial cases ``data["fixed_code"]`` holds the
              normalized code.
          """
          raw = parameters.get("code")
          # A missing or explicit None value means "no code"; str(None) would
          # otherwise be validated as the literal diagram text "None".
          code = "" if raw is None else str(raw).strip()
          if not code:
              return ToolResponse.error(
                  code=ToolErrorCode.INVALID_PARAM,
                  message="code 参数不能为空",
              )
          normalized = self._normalize(code)
          valid, errors = self._validate_structure(normalized)
          if valid:
              return ToolResponse.success(
                  text=f"VALID\n{normalized}",
                  data={"valid": True, "fixed_code": normalized, "errors": []},
              )
          return ToolResponse.partial(
              text=f"INVALID\n{normalized}\n错误: {'; '.join(errors)}",
              data={"valid": False, "fixed_code": normalized, "errors": errors},
          )

      def _normalize(self, code: str) -> str:
          """Return a cleaned-up copy of ``code``.

          Cleaning steps, in order:
          - strip surrounding whitespace and every triple-backtick fence,
            including the "mermaid" fence opener;
          - replace the Unicode arrow ``→`` with ``-->``;
          - remove blank lines and trailing whitespace.

          ``flowchart TD`` is then inserted when the first diagram line is not a
          recognized diagram-type declaration. That line is the first one after
          any YAML front matter and ``%%`` comment/directive lines. Input that is
          empty after cleaning yields a one-node placeholder flowchart.
          """
          code = code.strip()
          code = code.replace("```mermaid", "").replace("```", "").strip()
          code = code.replace("→", "-->")
          lines = [ln.rstrip() for ln in code.splitlines() if ln.strip()]
          if not lines:
              return "flowchart TD\n A[空图]"
          idx = self._declaration_index(lines)
          if idx >= len(lines) or not self._diagram_keyword(lines[idx]):
              # Fall back to a flowchart. Insert it after any front matter or
              # %% directives so those stay at the top where Mermaid expects them.
              lines.insert(idx, "flowchart TD")
          return "\n".join(lines)

      def _validate_structure(self, code: str):
          """Run heuristic structural checks on (normalized) Mermaid code.

          Returns:
              A ``(valid, errors)`` tuple. ``errors`` is a list of
              human-readable messages, and ``valid`` is True exactly when that
              list is empty.
          """
          errors: List[str] = []
          lines = code.splitlines()
          if not lines:
              return False, ["代码为空"]
          idx = self._declaration_index(lines)
          first = lines[idx].strip() if idx < len(lines) else ""
          keyword = self._diagram_keyword(first)
          if not keyword:
              errors.append("缺少 Mermaid 图类型声明")
          # Count brackets only in diagram lines. Front matter and %% lines are
          # skipped, and quoted label text is removed because it may legally
          # contain unbalanced brackets.
          # NOTE(review): asymmetric flowchart nodes (``id>text]``) still show
          # up as a [] mismatch, as do unbalanced brackets in unquoted free
          # text (e.g. sequence-diagram messages).
          body = "\n".join(
              self._QUOTED_RE.sub("", ln)
              for ln in lines[idx:]
              if not ln.strip().startswith("%%")
          )
          for left, right in self._BRACKET_PAIRS:
              if body.count(left) != body.count(right):
                  errors.append(f"括号不匹配: {left}{right}")
          # Common flowchart mistake: a declaration with no nodes or links.
          # The text after the keyword is included so single-line forms such
          # as "graph TD; A-->B" are recognized.
          if keyword.split("-")[0] in ("flowchart", "graph"):
              candidates = [first[len(keyword):]] + [
                  ln for ln in lines[idx + 1:] if not ln.strip().startswith("%%")
              ]
              if not any(self._FLOW_ELEMENT_RE.search(ln) for ln in candidates):
                  errors.append("flowchart 缺少节点或连线")
          return len(errors) == 0, errors

      @staticmethod
      def _declaration_index(lines: List[str]) -> int:
          """Return the index of the line that should hold the diagram type.

          Skips a leading YAML front-matter block (a first line of ``---`` up
          to the next ``---`` line), then any blank or ``%%`` comment/directive
          lines. An unterminated ``---`` block is not treated as front matter.
          Returns ``len(lines)`` when every line is skipped.
          """
          idx = 0
          if lines and lines[0].strip() == "---":
              for pos in range(1, len(lines)):
                  if lines[pos].strip() == "---":
                      idx = pos + 1
                      break
          while idx < len(lines) and (
              not lines[idx].strip() or lines[idx].strip().startswith("%%")
          ):
              idx += 1
          return idx

      @classmethod
      def _diagram_keyword(cls, line: str) -> str:
          """Return the diagram-type keyword that starts ``line``, or ``""``.

          The leading word (letters/digits, optionally joined by hyphens) must
          equal an entry of ``MERMAID_PREFIXES`` or extend it with a
          ``-suffix`` (e.g. ``stateDiagram-v2``). Therefore a node id that
          merely begins with a keyword, such as ``graphA``, is not a
          declaration.
          """
          match = cls._KEYWORD_RE.match(line.strip())
          if not match:
              return ""
          word = match.group(0)
          for prefix in MERMAID_PREFIXES:
              if word == prefix or word.startswith(prefix + "-"):
                  return word
          return ""