config.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import os
  2. from enum import Enum
  3. from typing import Any, Optional
  4. from pydantic import BaseModel, Field
  class SearchAPI(Enum):
      """Web-search backends the research assistant can be configured to use.

      Members are looked up by their lowercase string value, e.g.
      ``SearchAPI("tavily")`` resolves to ``SearchAPI.TAVILY``.
      """

      # Declaration order is preserved on iteration; keep it stable.
      PERPLEXITY = "perplexity"
      TAVILY = "tavily"
      DUCKDUCKGO = "duckduckgo"
      SEARXNG = "searxng"
      ADVANCED = "advanced"
  class Configuration(BaseModel):
      """Configuration options for the deep research assistant.

      Instances can be built directly or via :meth:`from_env`. That method reads
      each field from the environment variable named after the field in upper
      case (``search_api`` -> ``SEARCH_API``), then applies explicit overrides.
      """

      max_web_research_loops: int = Field(
          default=3,
          title="Research Depth",
          description="Number of research iterations to perform",
      )
      local_llm: str = Field(
          default="llama3.2",
          title="Local Model Name",
          description="Name of the locally hosted LLM (Ollama/LMStudio)",
      )
      llm_provider: str = Field(
          default="ollama",
          title="LLM Provider",
          description="Provider identifier (ollama, lmstudio, or custom)",
      )
      search_api: SearchAPI = Field(
          default=SearchAPI.DUCKDUCKGO,
          title="Search API",
          description="Web search API to use",
      )
      enable_notes: bool = Field(
          default=True,
          title="Enable Notes",
          description="Whether to store task progress in NoteTool",
      )
      notes_workspace: str = Field(
          default="./notes",
          title="Notes Workspace",
          description="Directory for NoteTool to persist task notes",
      )
      fetch_full_page: bool = Field(
          default=True,
          title="Fetch Full Page",
          description="Include the full page content in the search results",
      )
      ollama_base_url: str = Field(
          default="http://localhost:11434",
          title="Ollama Base URL",
          description="Base URL for Ollama API (without /v1 suffix)",
      )
      lmstudio_base_url: str = Field(
          default="http://localhost:1234/v1",
          title="LMStudio Base URL",
          description="Base URL for LMStudio OpenAI-compatible API",
      )
      strip_thinking_tokens: bool = Field(
          default=True,
          title="Strip Thinking Tokens",
          description="Whether to strip <think> tokens from model responses",
      )
      use_tool_calling: bool = Field(
          default=False,
          title="Use Tool Calling",
          description="Use tool calling instead of JSON mode for structured output",
      )
      llm_api_key: Optional[str] = Field(
          default=None,
          title="LLM API Key",
          description="Optional API key when using custom OpenAI-compatible services",
      )
      llm_base_url: Optional[str] = Field(
          default=None,
          title="LLM Base URL",
          description="Optional base URL when using custom OpenAI-compatible services",
      )
      llm_model_id: Optional[str] = Field(
          default=None,
          title="LLM Model ID",
          description="Optional model identifier for custom OpenAI-compatible services",
      )

      @classmethod
      def from_env(cls, overrides: Optional[dict[str, Any]] = None) -> "Configuration":
          """Create a configuration from environment variables and explicit overrides.

          Precedence, highest first:

          1. Entries in ``overrides`` whose value is not ``None``.
          2. Non-blank environment variables named after the field in upper case
             (``LOCAL_LLM``, ``SEARCH_API``, ``MAX_WEB_RESEARCH_LOOPS``, ...).
          3. The defaults declared on the fields above.

          Args:
              overrides: Optional mapping of field name to value. ``None`` values
                  are ignored, so callers can forward unset optional arguments
                  without clobbering environment or default values.

          Returns:
              A validated ``Configuration`` instance. Environment strings are
              coerced to the field types by pydantic.

          Raises:
              pydantic.ValidationError: if a supplied value cannot be converted
                  to its field's type.
          """
          raw_values: dict[str, Any] = {}

          # Every field is read from its upper-cased name. This single loop
          # covers all of them, so no separate alias table is needed.
          # A blank value (e.g. an exported-but-empty ``ENABLE_NOTES=``) is
          # treated as unset. Forwarding "" would make int/bool validation
          # fail instead of falling back to the field default.
          for field_name in cls.model_fields:
              env_value = os.environ.get(field_name.upper())
              if env_value is not None and env_value.strip():
                  raw_values[field_name] = env_value

          if overrides:
              raw_values.update(
                  {key: value for key, value in overrides.items() if value is not None}
              )

          return cls(**raw_values)

      def sanitized_ollama_url(self) -> str:
          """Return ``ollama_base_url`` ending in ``/v1``, as OpenAI clients require.

          Trailing slashes are stripped first. Both ``http://host:11434/`` and
          ``http://host:11434/v1/`` therefore become ``http://host:11434/v1``.
          """
          base = self.ollama_base_url.rstrip("/")
          if not base.endswith("/v1"):
              base = f"{base}/v1"
          return base

      def resolved_model(self) -> Optional[str]:
          """Return the model identifier to use.

          ``llm_model_id`` takes priority when it is a non-empty string.
          Otherwise this falls back to ``local_llm``.
          """
          return self.llm_model_id or self.local_llm