main.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. """FastAPI entrypoint exposing the DeepResearchAgent via HTTP."""
  2. from __future__ import annotations
  3. import json
  4. import sys
  5. from typing import Any, Dict, Iterator, Optional
  6. from fastapi import FastAPI, HTTPException
  7. from fastapi.middleware.cors import CORSMiddleware
  8. from fastapi.responses import StreamingResponse
  9. from loguru import logger
  10. from pydantic import BaseModel, Field
  11. from config import Configuration, SearchAPI
  12. from agent import DeepResearchAgent
  13. # 添加控制台日志处理程序
  14. logger.add(
  15. sys.stderr,
  16. level="INFO",
  17. format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <4}</level> | <cyan>using_function:{function}</cyan> | <cyan>{file}:{line}</cyan> | <level>{message}</level>",
  18. colorize=True,
  19. )
  20. # 添加错误日志文件处理程序
  21. logger.add(
  22. sink=sys.stderr,
  23. level="ERROR",
  24. format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <4}</level> | <cyan>using_function:{function}</cyan> | <cyan>{file}:{line}</cyan> | <level>{message}</level>",
  25. colorize=True,
  26. )
  27. class ResearchRequest(BaseModel):
  28. """Payload for triggering a research run."""
  29. topic: str = Field(..., description="Research topic supplied by the user")
  30. search_api: SearchAPI | None = Field(
  31. default=None,
  32. description="Override the default search backend configured via env",
  33. )
  34. class ResearchResponse(BaseModel):
  35. """HTTP response containing the generated report and structured tasks."""
  36. report_markdown: str = Field(
  37. ..., description="Markdown-formatted research report including sections"
  38. )
  39. todo_items: list[dict[str, Any]] = Field(
  40. default_factory=list,
  41. description="Structured TODO items with summaries and sources",
  42. )
  43. def _mask_secret(value: Optional[str], visible: int = 4) -> str:
  44. """Mask sensitive tokens while keeping leading and trailing characters."""
  45. if not value:
  46. return "unset"
  47. if len(value) <= visible * 2:
  48. return "*" * len(value)
  49. return f"{value[:visible]}...{value[-visible:]}"
  50. def _build_config(payload: ResearchRequest) -> Configuration:
  51. overrides: Dict[str, Any] = {}
  52. if payload.search_api is not None:
  53. overrides["search_api"] = payload.search_api
  54. return Configuration.from_env(overrides=overrides)
  55. def create_app() -> FastAPI:
  56. app = FastAPI(title="HelloAgents Deep Researcher")
  57. app.add_middleware(
  58. CORSMiddleware,
  59. allow_origins=["*"],
  60. allow_credentials=True,
  61. allow_methods=["*"],
  62. allow_headers=["*"],
  63. )
  64. @app.on_event("startup")
  65. def log_startup_configuration() -> None:
  66. config = Configuration.from_env()
  67. if config.llm_provider == "ollama":
  68. base_url = config.sanitized_ollama_url()
  69. elif config.llm_provider == "lmstudio":
  70. base_url = config.lmstudio_base_url
  71. else:
  72. base_url = config.llm_base_url or "unset"
  73. logger.info(
  74. "DeepResearch configuration loaded: provider=%s model=%s base_url=%s search_api=%s "
  75. "max_loops=%s fetch_full_page=%s tool_calling=%s strip_thinking=%s api_key=%s",
  76. config.llm_provider,
  77. config.resolved_model() or "unset",
  78. base_url,
  79. (config.search_api.value if isinstance(config.search_api, SearchAPI) else config.search_api),
  80. config.max_web_research_loops,
  81. config.fetch_full_page,
  82. config.use_tool_calling,
  83. config.strip_thinking_tokens,
  84. _mask_secret(config.llm_api_key),
  85. )
  86. @app.get("/healthz")
  87. def health_check() -> Dict[str, str]:
  88. return {"status": "ok"}
  89. @app.post("/research", response_model=ResearchResponse)
  90. def run_research(payload: ResearchRequest) -> ResearchResponse:
  91. try:
  92. config = _build_config(payload)
  93. agent = DeepResearchAgent(config=config)
  94. result = agent.run(payload.topic)
  95. except ValueError as exc: # Likely due to unsupported configuration
  96. raise HTTPException(status_code=400, detail=str(exc)) from exc
  97. except Exception as exc: # pragma: no cover - defensive guardrail
  98. raise HTTPException(status_code=500, detail="Research failed") from exc
  99. todo_payload = [
  100. {
  101. "id": item.id,
  102. "title": item.title,
  103. "intent": item.intent,
  104. "query": item.query,
  105. "status": item.status,
  106. "summary": item.summary,
  107. "sources_summary": item.sources_summary,
  108. "note_id": item.note_id,
  109. "note_path": item.note_path,
  110. }
  111. for item in result.todo_items
  112. ]
  113. return ResearchResponse(
  114. report_markdown=(result.report_markdown or result.running_summary or ""),
  115. todo_items=todo_payload,
  116. )
  117. @app.post("/research/stream")
  118. def stream_research(payload: ResearchRequest) -> StreamingResponse:
  119. try:
  120. config = _build_config(payload)
  121. agent = DeepResearchAgent(config=config)
  122. except ValueError as exc:
  123. raise HTTPException(status_code=400, detail=str(exc)) from exc
  124. def event_iterator() -> Iterator[str]:
  125. try:
  126. for event in agent.run_stream(payload.topic):
  127. yield f"data: {json.dumps(event, ensure_ascii=False)}\n\n"
  128. except Exception as exc: # pragma: no cover - defensive guardrail
  129. logger.exception("Streaming research failed")
  130. error_payload = {"type": "error", "detail": str(exc)}
  131. yield f"data: {json.dumps(error_payload, ensure_ascii=False)}\n\n"
  132. return StreamingResponse(
  133. event_iterator(),
  134. media_type="text/event-stream",
  135. headers={
  136. "Cache-Control": "no-cache",
  137. "Connection": "keep-alive",
  138. },
  139. )
  140. return app
  141. app = create_app()
  142. if __name__ == "__main__":
  143. import uvicorn
  144. uvicorn.run(
  145. "main:app",
  146. host="0.0.0.0",
  147. port=8000,
  148. reload=True,
  149. log_level="info"
  150. )