llm.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. """HelloAgents统一LLM接口 - 基于OpenAI原生API"""
  2. import os
  3. from typing import Literal, Optional, Iterator
  4. from openai import OpenAI
  5. from .exceptions import HelloAgentsException
  6. # 支持的LLM提供商
  7. SUPPORTED_PROVIDERS = Literal[
  8. "openai", "deepseek", "qwen", "modelscope",
  9. "kimi", "zhipu", "ollama", "vllm", "local", "auto"
  10. ]
  11. class HelloAgentsLLM:
  12. """
  13. 为HelloAgents定制的LLM客户端。
  14. 它用于调用任何兼容OpenAI接口的服务,并默认使用流式响应。
  15. 设计理念:
  16. - 参数优先,环境变量兜底
  17. - 流式响应为默认,提供更好的用户体验
  18. - 支持多种LLM提供商
  19. - 统一的调用接口
  20. """
  21. def __init__(
  22. self,
  23. model: Optional[str] = None,
  24. api_key: Optional[str] = None,
  25. base_url: Optional[str] = None,
  26. provider: Optional[SUPPORTED_PROVIDERS] = None,
  27. temperature: float = 0.7,
  28. max_tokens: Optional[int] = None,
  29. timeout: Optional[int] = None,
  30. **kwargs
  31. ):
  32. """
  33. 初始化客户端。优先使用传入参数,如果未提供,则从环境变量加载。
  34. 支持自动检测provider或使用统一的LLM_*环境变量配置。
  35. Args:
  36. model: 模型名称,如果未提供则从环境变量LLM_MODEL_ID读取
  37. api_key: API密钥,如果未提供则从环境变量读取
  38. base_url: 服务地址,如果未提供则从环境变量LLM_BASE_URL读取
  39. provider: LLM提供商,如果未提供则自动检测
  40. temperature: 温度参数
  41. max_tokens: 最大token数
  42. timeout: 超时时间,从环境变量LLM_TIMEOUT读取,默认60秒
  43. """
  44. # 优先使用传入参数,如果未提供,则从环境变量加载
  45. self.model = model or os.getenv("LLM_MODEL_ID")
  46. self.temperature = temperature
  47. self.max_tokens = max_tokens
  48. self.timeout = timeout or int(os.getenv("LLM_TIMEOUT", "60"))
  49. self.kwargs = kwargs
  50. # 自动检测provider或使用指定的provider
  51. self.provider = provider or self._auto_detect_provider(api_key, base_url)
  52. # 根据provider确定API密钥和base_url
  53. self.api_key, self.base_url = self._resolve_credentials(api_key, base_url)
  54. # 验证必要参数
  55. if not self.model:
  56. self.model = self._get_default_model()
  57. if not all([self.api_key, self.base_url]):
  58. raise HelloAgentsException("API密钥和服务地址必须被提供或在.env文件中定义。")
  59. # 创建OpenAI客户端
  60. self._client = self._create_client()
  61. def _auto_detect_provider(self, api_key: Optional[str], base_url: Optional[str]) -> str:
  62. """
  63. 自动检测LLM提供商
  64. 检测逻辑:
  65. 1. 优先检查特定提供商的环境变量
  66. 2. 根据API密钥格式判断
  67. 3. 根据base_url判断
  68. 4. 默认返回通用配置
  69. """
  70. # 1. 检查特定提供商的环境变量
  71. if os.getenv("OPENAI_API_KEY"):
  72. return "openai"
  73. if os.getenv("DEEPSEEK_API_KEY"):
  74. return "deepseek"
  75. if os.getenv("DASHSCOPE_API_KEY"):
  76. return "qwen"
  77. if os.getenv("MODELSCOPE_API_KEY"):
  78. return "modelscope"
  79. if os.getenv("KIMI_API_KEY") or os.getenv("MOONSHOT_API_KEY"):
  80. return "kimi"
  81. if os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY"):
  82. return "zhipu"
  83. if os.getenv("OLLAMA_API_KEY") or os.getenv("OLLAMA_HOST"):
  84. return "ollama"
  85. if os.getenv("VLLM_API_KEY") or os.getenv("VLLM_HOST"):
  86. return "vllm"
  87. # 2. 根据API密钥格式判断
  88. actual_api_key = api_key or os.getenv("LLM_API_KEY")
  89. if actual_api_key:
  90. actual_key_lower = actual_api_key.lower()
  91. if actual_api_key.startswith("ms-"):
  92. return "modelscope"
  93. elif actual_key_lower == "ollama":
  94. return "ollama"
  95. elif actual_key_lower == "vllm":
  96. return "vllm"
  97. elif actual_key_lower == "local":
  98. return "local"
  99. elif actual_api_key.startswith("sk-") and len(actual_api_key) > 50:
  100. # 可能是OpenAI、DeepSeek或Kimi,需要进一步判断
  101. pass
  102. elif actual_api_key.endswith(".") or "." in actual_api_key[-20:]:
  103. # 智谱AI的API密钥格式通常包含点号
  104. return "zhipu"
  105. # 3. 根据base_url判断
  106. actual_base_url = base_url or os.getenv("LLM_BASE_URL")
  107. if actual_base_url:
  108. base_url_lower = actual_base_url.lower()
  109. if "api.openai.com" in base_url_lower:
  110. return "openai"
  111. elif "api.deepseek.com" in base_url_lower:
  112. return "deepseek"
  113. elif "dashscope.aliyuncs.com" in base_url_lower:
  114. return "qwen"
  115. elif "api-inference.modelscope.cn" in base_url_lower:
  116. return "modelscope"
  117. elif "api.moonshot.cn" in base_url_lower:
  118. return "kimi"
  119. elif "open.bigmodel.cn" in base_url_lower:
  120. return "zhipu"
  121. elif "localhost" in base_url_lower or "127.0.0.1" in base_url_lower:
  122. # 本地部署检测 - 优先检查特定服务
  123. if ":11434" in base_url_lower or "ollama" in base_url_lower:
  124. return "ollama"
  125. elif ":8000" in base_url_lower and "vllm" in base_url_lower:
  126. return "vllm"
  127. elif ":8080" in base_url_lower or ":7860" in base_url_lower:
  128. return "local"
  129. else:
  130. # 根据API密钥进一步判断
  131. if actual_api_key and actual_api_key.lower() == "ollama":
  132. return "ollama"
  133. elif actual_api_key and actual_api_key.lower() == "vllm":
  134. return "vllm"
  135. else:
  136. return "local"
  137. elif any(port in base_url_lower for port in [":8080", ":7860", ":5000"]):
  138. # 常见的本地部署端口
  139. return "local"
  140. # 4. 默认返回auto,使用通用配置
  141. return "auto"
  142. def _resolve_credentials(self, api_key: Optional[str], base_url: Optional[str]) -> tuple[str, str]:
  143. """根据provider解析API密钥和base_url"""
  144. if self.provider == "openai":
  145. resolved_api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("LLM_API_KEY")
  146. resolved_base_url = base_url or os.getenv("LLM_BASE_URL") or "https://api.openai.com/v1"
  147. return resolved_api_key, resolved_base_url
  148. elif self.provider == "deepseek":
  149. resolved_api_key = api_key or os.getenv("DEEPSEEK_API_KEY") or os.getenv("LLM_API_KEY")
  150. resolved_base_url = base_url or os.getenv("LLM_BASE_URL") or "https://api.deepseek.com"
  151. return resolved_api_key, resolved_base_url
  152. elif self.provider == "qwen":
  153. resolved_api_key = api_key or os.getenv("DASHSCOPE_API_KEY") or os.getenv("LLM_API_KEY")
  154. resolved_base_url = base_url or os.getenv("LLM_BASE_URL") or "https://dashscope.aliyuncs.com/compatible-mode/v1"
  155. return resolved_api_key, resolved_base_url
  156. elif self.provider == "modelscope":
  157. resolved_api_key = api_key or os.getenv("MODELSCOPE_API_KEY") or os.getenv("LLM_API_KEY")
  158. resolved_base_url = base_url or os.getenv("LLM_BASE_URL") or "https://api-inference.modelscope.cn/v1/"
  159. return resolved_api_key, resolved_base_url
  160. elif self.provider == "kimi":
  161. resolved_api_key = api_key or os.getenv("KIMI_API_KEY") or os.getenv("MOONSHOT_API_KEY") or os.getenv("LLM_API_KEY")
  162. resolved_base_url = base_url or os.getenv("LLM_BASE_URL") or "https://api.moonshot.cn/v1"
  163. return resolved_api_key, resolved_base_url
  164. elif self.provider == "zhipu":
  165. resolved_api_key = api_key or os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY") or os.getenv("LLM_API_KEY")
  166. resolved_base_url = base_url or os.getenv("LLM_BASE_URL") or "https://open.bigmodel.cn/api/paas/v4"
  167. return resolved_api_key, resolved_base_url
  168. elif self.provider == "ollama":
  169. resolved_api_key = api_key or os.getenv("OLLAMA_API_KEY") or os.getenv("LLM_API_KEY") or "ollama"
  170. resolved_base_url = base_url or os.getenv("OLLAMA_HOST") or os.getenv("LLM_BASE_URL") or "http://localhost:11434/v1"
  171. return resolved_api_key, resolved_base_url
  172. elif self.provider == "vllm":
  173. resolved_api_key = api_key or os.getenv("VLLM_API_KEY") or os.getenv("LLM_API_KEY") or "vllm"
  174. resolved_base_url = base_url or os.getenv("VLLM_HOST") or os.getenv("LLM_BASE_URL") or "http://localhost:8000/v1"
  175. return resolved_api_key, resolved_base_url
  176. elif self.provider == "local":
  177. resolved_api_key = api_key or os.getenv("LLM_API_KEY") or "local"
  178. resolved_base_url = base_url or os.getenv("LLM_BASE_URL") or "http://localhost:8000/v1"
  179. return resolved_api_key, resolved_base_url
  180. else:
  181. # auto或其他情况:使用通用配置,支持任何OpenAI兼容的服务
  182. resolved_api_key = api_key or os.getenv("LLM_API_KEY")
  183. resolved_base_url = base_url or os.getenv("LLM_BASE_URL")
  184. return resolved_api_key, resolved_base_url
  185. def _create_client(self) -> OpenAI:
  186. """创建OpenAI客户端"""
  187. return OpenAI(
  188. api_key=self.api_key,
  189. base_url=self.base_url,
  190. timeout=self.timeout
  191. )
  192. def _get_default_model(self) -> str:
  193. """获取默认模型"""
  194. if self.provider == "openai":
  195. return "gpt-3.5-turbo"
  196. elif self.provider == "deepseek":
  197. return "deepseek-chat"
  198. elif self.provider == "qwen":
  199. return "qwen-plus"
  200. elif self.provider == "modelscope":
  201. return "Qwen/Qwen2.5-72B-Instruct"
  202. elif self.provider == "kimi":
  203. return "moonshot-v1-8k"
  204. elif self.provider == "zhipu":
  205. return "glm-4"
  206. elif self.provider == "ollama":
  207. return "llama3.2" # Ollama常用模型
  208. elif self.provider == "vllm":
  209. return "meta-llama/Llama-2-7b-chat-hf" # vLLM常用模型
  210. elif self.provider == "local":
  211. return "local-model" # 本地模型占位符
  212. else:
  213. # auto或其他情况:根据base_url智能推断默认模型
  214. base_url = os.getenv("LLM_BASE_URL", "")
  215. base_url_lower = base_url.lower()
  216. if "modelscope" in base_url_lower:
  217. return "Qwen/Qwen2.5-72B-Instruct"
  218. elif "deepseek" in base_url_lower:
  219. return "deepseek-chat"
  220. elif "dashscope" in base_url_lower:
  221. return "qwen-plus"
  222. elif "moonshot" in base_url_lower:
  223. return "moonshot-v1-8k"
  224. elif "bigmodel" in base_url_lower:
  225. return "glm-4"
  226. elif "ollama" in base_url_lower or ":11434" in base_url_lower:
  227. return "llama3.2"
  228. elif ":8000" in base_url_lower or "vllm" in base_url_lower:
  229. return "meta-llama/Llama-2-7b-chat-hf"
  230. elif "localhost" in base_url_lower or "127.0.0.1" in base_url_lower:
  231. return "local-model"
  232. else:
  233. return "gpt-3.5-turbo"
  234. def think(self, messages: list[dict[str, str]], temperature: Optional[float] = None) -> Iterator[str]:
  235. """
  236. 调用大语言模型进行思考,并返回流式响应。
  237. 这是主要的调用方法,默认使用流式响应以获得更好的用户体验。
  238. Args:
  239. messages: 消息列表
  240. temperature: 温度参数,如果未提供则使用初始化时的值
  241. Yields:
  242. str: 流式响应的文本片段
  243. """
  244. print(f"🧠 正在调用 {self.model} 模型...")
  245. try:
  246. response = self._client.chat.completions.create(
  247. model=self.model,
  248. messages=messages,
  249. temperature=temperature if temperature is not None else self.temperature,
  250. max_tokens=self.max_tokens,
  251. stream=True,
  252. )
  253. # 处理流式响应
  254. print("✅ 大语言模型响应成功:")
  255. for chunk in response:
  256. content = chunk.choices[0].delta.content or ""
  257. if content:
  258. print(content, end="", flush=True)
  259. yield content
  260. print() # 在流式输出结束后换行
  261. except Exception as e:
  262. print(f"❌ 调用LLM API时发生错误: {e}")
  263. raise HelloAgentsException(f"LLM调用失败: {str(e)}")
  264. def invoke(self, messages: list[dict[str, str]], **kwargs) -> str:
  265. """
  266. 非流式调用LLM,返回完整响应。
  267. 适用于不需要流式输出的场景。
  268. """
  269. try:
  270. response = self._client.chat.completions.create(
  271. model=self.model,
  272. messages=messages,
  273. temperature=kwargs.get('temperature', self.temperature),
  274. max_tokens=kwargs.get('max_tokens', self.max_tokens),
  275. **{k: v for k, v in kwargs.items() if k not in ['temperature', 'max_tokens']}
  276. )
  277. return response.choices[0].message.content
  278. except Exception as e:
  279. raise HelloAgentsException(f"LLM调用失败: {str(e)}")
  280. def stream_invoke(self, messages: list[dict[str, str]], **kwargs) -> Iterator[str]:
  281. """
  282. 流式调用LLM的别名方法,与think方法功能相同。
  283. 保持向后兼容性。
  284. """
  285. temperature = kwargs.get('temperature')
  286. yield from self.think(messages, temperature)