llm.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. """HelloAgents统一LLM接口 - 基于OpenAI原生API"""
  2. import os
  3. from typing import Literal, Optional, Iterator
  4. from openai import OpenAI
  5. from .exceptions import HelloAgentsException
  6. # 支持的LLM提供商
  7. SUPPORTED_PROVIDERS = Literal[
  8. "openai",
  9. "deepseek",
  10. "qwen",
  11. "modelscope",
  12. "kimi",
  13. "zhipu",
  14. "ollama",
  15. "vllm",
  16. "local",
  17. "auto",
  18. "custom",
  19. ]
class HelloAgentsLLM:
    """
    LLM client tailored for HelloAgents.

    It calls any service that exposes an OpenAI-compatible interface and
    uses streaming responses by default.

    Design principles:
    - Explicit arguments take precedence; environment variables are the fallback
    - Streaming responses are the default, for a better user experience
    - Multiple LLM providers are supported
    - One unified calling interface
    """
  30. def __init__(
  31. self,
  32. model: Optional[str] = None,
  33. api_key: Optional[str] = None,
  34. base_url: Optional[str] = None,
  35. provider: Optional[SUPPORTED_PROVIDERS] = None,
  36. temperature: float = 0.7,
  37. max_tokens: Optional[int] = None,
  38. timeout: Optional[int] = None,
  39. **kwargs,
  40. ):
  41. """
  42. 初始化客户端。优先使用传入参数,如果未提供,则从环境变量加载。
  43. 支持自动检测provider或使用统一的LLM_*环境变量配置。
  44. Args:
  45. model: 模型名称,如果未提供则从环境变量LLM_MODEL_ID读取
  46. api_key: API密钥,如果未提供则从环境变量读取
  47. base_url: 服务地址,如果未提供则从环境变量LLM_BASE_URL读取
  48. provider: LLM提供商,如果未提供则自动检测
  49. temperature: 温度参数
  50. max_tokens: 最大token数
  51. timeout: 超时时间,从环境变量LLM_TIMEOUT读取,默认60秒
  52. """
  53. # 优先使用传入参数,如果未提供,则从环境变量加载
  54. self.model = model or os.getenv("LLM_MODEL_ID")
  55. self.temperature = temperature
  56. self.max_tokens = max_tokens
  57. self.timeout = timeout or int(os.getenv("LLM_TIMEOUT", "60"))
  58. self.kwargs = kwargs
  59. # 自动检测provider或使用指定的provider
  60. requested_provider = (provider or "").lower() if provider else None
  61. self.provider = provider or self._auto_detect_provider(api_key, base_url)
  62. if requested_provider == "custom":
  63. self.provider = "custom"
  64. self.api_key = api_key or os.getenv("LLM_API_KEY")
  65. self.base_url = base_url or os.getenv("LLM_BASE_URL")
  66. else:
  67. # 根据provider确定API密钥和base_url
  68. self.api_key, self.base_url = self._resolve_credentials(api_key, base_url)
  69. # 验证必要参数
  70. if not self.model:
  71. self.model = self._get_default_model()
  72. if not all([self.api_key, self.base_url]):
  73. raise HelloAgentsException(
  74. "API密钥和服务地址必须被提供或在.env文件中定义。"
  75. )
  76. # 创建OpenAI客户端
  77. self._client = self._create_client()
  78. def _auto_detect_provider(
  79. self, api_key: Optional[str], base_url: Optional[str]
  80. ) -> str:
  81. """
  82. 自动检测LLM提供商
  83. 检测逻辑:
  84. 1. 优先检查特定提供商的环境变量
  85. 2. 根据API密钥格式判断
  86. 3. 根据base_url判断
  87. 4. 默认返回通用配置
  88. """
  89. # 1. 检查特定提供商的环境变量
  90. if os.getenv("OPENAI_API_KEY"):
  91. return "openai"
  92. if os.getenv("DEEPSEEK_API_KEY"):
  93. return "deepseek"
  94. if os.getenv("DASHSCOPE_API_KEY"):
  95. return "qwen"
  96. if os.getenv("MODELSCOPE_API_KEY"):
  97. return "modelscope"
  98. if os.getenv("KIMI_API_KEY") or os.getenv("MOONSHOT_API_KEY"):
  99. return "kimi"
  100. if os.getenv("ZHIPU_API_KEY") or os.getenv("GLM_API_KEY"):
  101. return "zhipu"
  102. if os.getenv("OLLAMA_API_KEY") or os.getenv("OLLAMA_HOST"):
  103. return "ollama"
  104. if os.getenv("VLLM_API_KEY") or os.getenv("VLLM_HOST"):
  105. return "vllm"
  106. # 2. 根据API密钥格式判断
  107. actual_api_key = api_key or os.getenv("LLM_API_KEY")
  108. if actual_api_key:
  109. actual_key_lower = actual_api_key.lower()
  110. if actual_api_key.startswith("ms-"):
  111. return "modelscope"
  112. elif actual_key_lower == "ollama":
  113. return "ollama"
  114. elif actual_key_lower == "vllm":
  115. return "vllm"
  116. elif actual_key_lower == "local":
  117. return "local"
  118. elif actual_api_key.startswith("sk-") and len(actual_api_key) > 50:
  119. # 可能是OpenAI、DeepSeek或Kimi,需要进一步判断
  120. pass
  121. elif actual_api_key.endswith(".") or "." in actual_api_key[-20:]:
  122. # 智谱AI的API密钥格式通常包含点号
  123. return "zhipu"
  124. # 3. 根据base_url判断
  125. actual_base_url = base_url or os.getenv("LLM_BASE_URL")
  126. if actual_base_url:
  127. base_url_lower = actual_base_url.lower()
  128. if "api.openai.com" in base_url_lower:
  129. return "openai"
  130. elif "api.deepseek.com" in base_url_lower:
  131. return "deepseek"
  132. elif "dashscope.aliyuncs.com" in base_url_lower:
  133. return "qwen"
  134. elif "api-inference.modelscope.cn" in base_url_lower:
  135. return "modelscope"
  136. elif "api.moonshot.cn" in base_url_lower:
  137. return "kimi"
  138. elif "open.bigmodel.cn" in base_url_lower:
  139. return "zhipu"
  140. elif "localhost" in base_url_lower or "127.0.0.1" in base_url_lower:
  141. # 本地部署检测 - 优先检查特定服务
  142. if ":11434" in base_url_lower or "ollama" in base_url_lower:
  143. return "ollama"
  144. elif ":8000" in base_url_lower and "vllm" in base_url_lower:
  145. return "vllm"
  146. elif ":8080" in base_url_lower or ":7860" in base_url_lower:
  147. return "local"
  148. else:
  149. # 根据API密钥进一步判断
  150. if actual_api_key and actual_api_key.lower() == "ollama":
  151. return "ollama"
  152. elif actual_api_key and actual_api_key.lower() == "vllm":
  153. return "vllm"
  154. else:
  155. return "local"
  156. elif any(port in base_url_lower for port in [":8080", ":7860", ":5000"]):
  157. # 常见的本地部署端口
  158. return "local"
  159. # 4. 默认返回auto,使用通用配置
  160. return "auto"
  161. def _resolve_credentials(
  162. self, api_key: Optional[str], base_url: Optional[str]
  163. ) -> tuple[str, str]:
  164. """根据provider解析API密钥和base_url"""
  165. if self.provider == "openai":
  166. resolved_api_key = (
  167. api_key or os.getenv("OPENAI_API_KEY") or os.getenv("LLM_API_KEY")
  168. )
  169. resolved_base_url = (
  170. base_url or os.getenv("LLM_BASE_URL") or "https://api.openai.com/v1"
  171. )
  172. return resolved_api_key, resolved_base_url
  173. elif self.provider == "deepseek":
  174. resolved_api_key = (
  175. api_key or os.getenv("DEEPSEEK_API_KEY") or os.getenv("LLM_API_KEY")
  176. )
  177. resolved_base_url = (
  178. base_url or os.getenv("LLM_BASE_URL") or "https://api.deepseek.com"
  179. )
  180. return resolved_api_key, resolved_base_url
  181. elif self.provider == "qwen":
  182. resolved_api_key = (
  183. api_key or os.getenv("DASHSCOPE_API_KEY") or os.getenv("LLM_API_KEY")
  184. )
  185. resolved_base_url = (
  186. base_url
  187. or os.getenv("LLM_BASE_URL")
  188. or "https://dashscope.aliyuncs.com/compatible-mode/v1"
  189. )
  190. return resolved_api_key, resolved_base_url
  191. elif self.provider == "modelscope":
  192. resolved_api_key = (
  193. api_key or os.getenv("MODELSCOPE_API_KEY") or os.getenv("LLM_API_KEY")
  194. )
  195. resolved_base_url = (
  196. base_url
  197. or os.getenv("LLM_BASE_URL")
  198. or "https://api-inference.modelscope.cn/v1/"
  199. )
  200. return resolved_api_key, resolved_base_url
  201. elif self.provider == "kimi":
  202. resolved_api_key = (
  203. api_key
  204. or os.getenv("KIMI_API_KEY")
  205. or os.getenv("MOONSHOT_API_KEY")
  206. or os.getenv("LLM_API_KEY")
  207. )
  208. resolved_base_url = (
  209. base_url or os.getenv("LLM_BASE_URL") or "https://api.moonshot.cn/v1"
  210. )
  211. return resolved_api_key, resolved_base_url
  212. elif self.provider == "zhipu":
  213. resolved_api_key = (
  214. api_key
  215. or os.getenv("ZHIPU_API_KEY")
  216. or os.getenv("GLM_API_KEY")
  217. or os.getenv("LLM_API_KEY")
  218. )
  219. resolved_base_url = (
  220. base_url
  221. or os.getenv("LLM_BASE_URL")
  222. or "https://open.bigmodel.cn/api/paas/v4"
  223. )
  224. return resolved_api_key, resolved_base_url
  225. elif self.provider == "ollama":
  226. resolved_api_key = (
  227. api_key
  228. or os.getenv("OLLAMA_API_KEY")
  229. or os.getenv("LLM_API_KEY")
  230. or "ollama"
  231. )
  232. resolved_base_url = (
  233. base_url
  234. or os.getenv("OLLAMA_HOST")
  235. or os.getenv("LLM_BASE_URL")
  236. or "http://localhost:11434/v1"
  237. )
  238. return resolved_api_key, resolved_base_url
  239. elif self.provider == "vllm":
  240. resolved_api_key = (
  241. api_key
  242. or os.getenv("VLLM_API_KEY")
  243. or os.getenv("LLM_API_KEY")
  244. or "vllm"
  245. )
  246. resolved_base_url = (
  247. base_url
  248. or os.getenv("VLLM_HOST")
  249. or os.getenv("LLM_BASE_URL")
  250. or "http://localhost:8000/v1"
  251. )
  252. return resolved_api_key, resolved_base_url
  253. elif self.provider == "local":
  254. resolved_api_key = api_key or os.getenv("LLM_API_KEY") or "local"
  255. resolved_base_url = (
  256. base_url or os.getenv("LLM_BASE_URL") or "http://localhost:8000/v1"
  257. )
  258. return resolved_api_key, resolved_base_url
  259. elif self.provider == "custom":
  260. resolved_api_key = api_key or os.getenv("LLM_API_KEY")
  261. resolved_base_url = base_url or os.getenv("LLM_BASE_URL")
  262. return resolved_api_key, resolved_base_url
  263. else:
  264. # auto或其他情况:使用通用配置,支持任何OpenAI兼容的服务
  265. resolved_api_key = api_key or os.getenv("LLM_API_KEY")
  266. resolved_base_url = base_url or os.getenv("LLM_BASE_URL")
  267. return resolved_api_key, resolved_base_url
  268. def _create_client(self) -> OpenAI:
  269. """创建OpenAI客户端"""
  270. return OpenAI(
  271. api_key=self.api_key, base_url=self.base_url, timeout=self.timeout
  272. )
  273. def _get_default_model(self) -> str:
  274. """获取默认模型"""
  275. if self.provider == "openai":
  276. return "gpt-3.5-turbo"
  277. elif self.provider == "deepseek":
  278. return "deepseek-chat"
  279. elif self.provider == "qwen":
  280. return "qwen-plus"
  281. elif self.provider == "modelscope":
  282. return "Qwen/Qwen2.5-72B-Instruct"
  283. elif self.provider == "kimi":
  284. return "moonshot-v1-8k"
  285. elif self.provider == "zhipu":
  286. return "glm-4"
  287. elif self.provider == "ollama":
  288. return "llama3.2" # Ollama常用模型
  289. elif self.provider == "vllm":
  290. return "meta-llama/Llama-2-7b-chat-hf" # vLLM常用模型
  291. elif self.provider == "local":
  292. return "local-model" # 本地模型占位符
  293. elif self.provider == "custom":
  294. return self.model or "gpt-3.5-turbo"
  295. else:
  296. # auto或其他情况:根据base_url智能推断默认模型
  297. base_url = os.getenv("LLM_BASE_URL", "")
  298. base_url_lower = base_url.lower()
  299. if "modelscope" in base_url_lower:
  300. return "Qwen/Qwen2.5-72B-Instruct"
  301. elif "deepseek" in base_url_lower:
  302. return "deepseek-chat"
  303. elif "dashscope" in base_url_lower:
  304. return "qwen-plus"
  305. elif "moonshot" in base_url_lower:
  306. return "moonshot-v1-8k"
  307. elif "bigmodel" in base_url_lower:
  308. return "glm-4"
  309. elif "ollama" in base_url_lower or ":11434" in base_url_lower:
  310. return "llama3.2"
  311. elif ":8000" in base_url_lower or "vllm" in base_url_lower:
  312. return "meta-llama/Llama-2-7b-chat-hf"
  313. elif "localhost" in base_url_lower or "127.0.0.1" in base_url_lower:
  314. return "local-model"
  315. else:
  316. return "gpt-3.5-turbo"
  317. def think(
  318. self, messages: list[dict[str, str]], temperature: Optional[float] = None
  319. ) -> Iterator[str]:
  320. """
  321. 调用大语言模型进行思考,并返回流式响应。
  322. 这是主要的调用方法,默认使用流式响应以获得更好的用户体验。
  323. Args:
  324. messages: 消息列表
  325. temperature: 温度参数,如果未提供则使用初始化时的值
  326. Yields:
  327. str: 流式响应的文本片段
  328. """
  329. print(f"🧠 正在调用 {self.model} 模型...")
  330. try:
  331. response = self._client.chat.completions.create(
  332. model=self.model,
  333. messages=messages,
  334. temperature=temperature
  335. if temperature is not None
  336. else self.temperature,
  337. max_tokens=self.max_tokens,
  338. stream=True,
  339. )
  340. # 处理流式响应
  341. print("✅ 大语言模型响应成功:")
  342. for chunk in response:
  343. content = chunk.choices[0].delta.content or ""
  344. if content:
  345. print(content, end="", flush=True)
  346. yield content
  347. print() # 在流式输出结束后换行
  348. except Exception as e:
  349. print(f"❌ 调用LLM API时发生错误: {e}")
  350. raise HelloAgentsException(f"LLM调用失败: {str(e)}")
  351. def invoke(self, messages: list[dict[str, str]], **kwargs) -> str:
  352. """
  353. 非流式调用LLM,返回完整响应。
  354. 适用于不需要流式输出的场景。
  355. """
  356. try:
  357. response = self._client.chat.completions.create(
  358. model=self.model,
  359. messages=messages,
  360. temperature=kwargs.get("temperature", self.temperature),
  361. max_tokens=kwargs.get("max_tokens", self.max_tokens),
  362. **{
  363. k: v
  364. for k, v in kwargs.items()
  365. if k not in ["temperature", "max_tokens"]
  366. },
  367. )
  368. return response.choices[0].message.content
  369. except Exception as e:
  370. raise HelloAgentsException(f"LLM调用失败: {str(e)}")
  371. def stream_invoke(self, messages: list[dict[str, str]], **kwargs) -> Iterator[str]:
  372. """
  373. 流式调用LLM的别名方法,与think方法功能相同。
  374. 保持向后兼容性。
  375. """
  376. temperature = kwargs.get("temperature")
  377. yield from self.think(messages, temperature)