app.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. """StockInsightAgent — Gradio 前端"""
  2. import threading
  3. import queue
  4. import gradio as gr
  5. from framework_agent import FrameworkStockAgent
  6. from memory import (
  7. memory_get_watchlist, memory_add_watchlist, memory_remove_watchlist,
  8. memory_get_history, memory_get_preferences,
  9. )
  10. from rag import rag_import, rag_stats
  11. _agent = None
  12. def get_agent():
  13. global _agent
  14. if _agent is None:
  15. _agent = FrameworkStockAgent()
  16. return _agent
  17. def _run_with_capture(q: queue.Queue, agent, mode: str, msg: str):
  18. import io, sys
  19. import contextlib
  20. buffer = io.StringIO()
  21. class QueueWriter:
  22. def __init__(self, original_stdout):
  23. self.original = original_stdout
  24. # Ensure _local exists for threads
  25. import threading
  26. self._local = threading.local()
  27. @property
  28. def is_active(self):
  29. return getattr(self._local, "active", False)
  30. @is_active.setter
  31. def is_active(self, value):
  32. self._local.active = value
  33. def write(self, s):
  34. if self.is_active:
  35. buffer.write(s)
  36. q.put(s)
  37. else:
  38. self.original.write(s)
  39. def flush(self):
  40. if not self.is_active:
  41. self.original.flush()
  42. if not isinstance(sys.stdout, QueueWriter):
  43. sys.stdout = QueueWriter(sys.stdout)
  44. sys.stdout.is_active = True
  45. try:
  46. if mode == "深度分析 (PlanSolve)":
  47. result = agent.plan_solve(msg)
  48. elif mode == "批判分析 (Reflection)":
  49. result = agent.reflect(msg)
  50. else:
  51. result = agent.react(msg)
  52. except Exception as e:
  53. result = f"分析出错: {e}"
  54. finally:
  55. sys.stdout.is_active = False
  56. q.put(None)
  57. q.result = result or ""
  58. def respond_stream(message: str, history: list, mode: str, agent=None):
  59. if agent is None:
  60. agent = get_agent()
  61. if not message or not message.strip():
  62. yield history, "", agent
  63. return
  64. msg = message.strip()
  65. history = history or []
  66. # ── 快捷命令 ──
  67. quick = {
  68. ("帮助", "help", "?"): lambda: HELP_TEXT,
  69. ("列表", "关注列表"): memory_get_watchlist,
  70. ("历史",): memory_get_history,
  71. ("偏好",): memory_get_preferences,
  72. ("知识库",): rag_stats,
  73. }
  74. for keys, handler in quick.items():
  75. if msg in keys:
  76. history.append({"role": "user", "content": msg})
  77. history.append({"role": "assistant", "content": handler()})
  78. yield history, "", agent
  79. return
  80. if msg.startswith("关注 "):
  81. parts = msg[3:].strip().split()
  82. c, n = parts[0], (parts[1] if len(parts) > 1 else "")
  83. history.append({"role": "user", "content": msg})
  84. history.append({"role": "assistant", "content": memory_add_watchlist(f"{c}|{n}")})
  85. yield history, "", agent
  86. return
  87. if msg.startswith("移除 "):
  88. history.append({"role": "user", "content": msg})
  89. history.append({"role": "assistant", "content": memory_remove_watchlist(msg[3:].strip())})
  90. yield history, "", agent
  91. return
  92. if msg.startswith("历史 "):
  93. history.append({"role": "user", "content": msg})
  94. history.append({"role": "assistant", "content": memory_get_history(msg[3:].strip())})
  95. yield history, "", agent
  96. return
  97. if msg.startswith("导入 "):
  98. history.append({"role": "user", "content": msg})
  99. history.append({"role": "assistant", "content": rag_import(msg[3:].strip()) + "\n" + rag_stats()})
  100. yield history, "", agent
  101. return
  102. # ── 流式分析 ──
  103. history.append({"role": "user", "content": msg})
  104. history.append({"role": "assistant", "content": "..."})
  105. q = queue.Queue()
  106. t = threading.Thread(target=_run_with_capture, args=(q, agent, mode, msg), daemon=True)
  107. t.start()
  108. collected = []
  109. while True:
  110. try:
  111. chunk = q.get(timeout=0.3)
  112. except queue.Empty:
  113. if collected:
  114. history[-1]["content"] = "".join(collected)
  115. yield history, "", agent
  116. continue
  117. if chunk is None:
  118. break
  119. collected.append(chunk)
  120. history[-1]["content"] = "".join(collected)
  121. yield history, "", agent
  122. final = getattr(q, 'result', '') or "".join(collected)
  123. history[-1]["content"] = str(final)
  124. yield history, "", agent
  125. HELP_TEXT = """## 使用指南
  126. ### 股票分析
  127. 直接输入:`分析贵州茅台600519的估值和风险`
  128. ### 关注管理
  129. | 命令 | 说明 |
  130. |------|------|
  131. | `列表` | 查看关注列表 |
  132. | `关注 600519 茅台` | 添加关注 |
  133. | `移除 600519` | 移除关注 |
  134. ### 数据查询
  135. | 命令 | 说明 |
  136. |------|------|
  137. | `历史` | 全部分析历史 |
  138. | `偏好` | 用户偏好设置 |
  139. | `知识库` | 知识库状态 |"""
# ===== Custom CSS =====
# Stylesheet passed to app.launch(css=...) in the __main__ block. It styles the
# classes used by the header/sidebar HTML snippets (header-wrap, status-badge,
# sidebar-card), the elem_classes set on components (main-chat, input-box,
# primary, secondary), plus generic selectors (footer, inputs, links,
# scrollbars).
# NOTE(review): every property of the first ".input-box textarea" rule is
# redeclared with !important at equal specificity by the later
# 'input[type="text"], textarea, .input-box textarea' rule, so the later
# values win and the first rule is effectively dead.
# NOTE(review): ".mode-radio-wrap" and ".quick-icon" are not referenced by the
# UI code in this file.
CUSTOM_CSS = """
/* 全局 */
.gradio-container {
max-width: 100% !important;
margin: 0 !important;
padding: 0 !important;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "PingFang SC", "Microsoft YaHei", sans-serif !important;
}
/* 隐藏默认 footer */
footer { display: none !important; }
/* Header */
.header-wrap {
background: linear-gradient(135deg, #0f1729 0%, #1a2744 50%, #0d2137 100%);
border-bottom: 2px solid #2a5c8a;
padding: 16px 32px;
}
.header-wrap h1 {
font-size: 26px;
font-weight: 700;
color: #e8f0fe;
margin: 0;
letter-spacing: -0.5px;
}
.header-wrap .subtitle {
font-size: 13px;
color: #7b9bcb;
margin-top: 4px;
}
.header-wrap .status-row {
display: flex;
gap: 16px;
margin-top: 10px;
flex-wrap: wrap;
}
.status-badge {
display: inline-flex;
align-items: center;
gap: 6px;
font-size: 11px;
padding: 3px 10px;
border-radius: 12px;
font-weight: 500;
}
.status-badge.online {
background: rgba(34, 197, 94, 0.15);
color: #4ade80;
}
.status-badge.data {
background: rgba(59, 130, 246, 0.15);
color: #60a5fa;
}
/* 侧边栏卡片 */
.sidebar-card {
background: rgba(255,255,255,0.04);
border: 1px solid rgba(255,255,255,0.08);
border-radius: 10px;
padding: 16px;
margin-bottom: 12px;
}
.sidebar-card h4 {
font-size: 12px;
font-weight: 600;
color: #7b9bcb;
text-transform: uppercase;
letter-spacing: 0.5px;
margin: 0 0 10px 0;
}
/* 主对话区域 */
.main-chat {
border-radius: 12px !important;
border: 1px solid rgba(255,255,255,0.1) !important;
background: rgba(255,255,255,0.02) !important;
}
/* 输入框 */
.input-box textarea {
border-radius: 10px !important;
border: 1px solid rgba(255,255,255,0.12) !important;
background: rgba(255,255,255,0.04) !important;
color: #e0e0e0 !important;
font-size: 14px !important;
padding: 12px 16px !important;
}
.input-box textarea::placeholder {
color: rgba(255,255,255,0.3) !important;
}
/* 按钮 */
button.primary {
background: linear-gradient(135deg, #2563eb, #1d4ed8) !important;
border: none !important;
border-radius: 10px !important;
color: white !important;
font-weight: 600 !important;
padding: 12px 24px !important;
transition: all 0.2s !important;
height: 100% !important;
min-height: 44px !important;
}
button.primary:hover {
background: linear-gradient(135deg, #3b82f6, #2563eb) !important;
box-shadow: 0 4px 12px rgba(37,99,235,0.3) !important;
}
button.secondary {
background: rgba(255,255,255,0.04) !important;
border: 1px solid rgba(255,255,255,0.08) !important;
border-radius: 8px !important;
color: #6aaff7 !important;
font-size: 12px !important;
padding: 8px 14px !important;
transition: all 0.2s !important;
width: 100% !important;
text-align: left !important;
}
button.secondary:hover {
background: rgba(255,255,255,0.08) !important;
color: #c8d6e5 !important;
border-color: rgba(255,255,255,0.18) !important;
}
/* Radio 模式选择 */
.mode-radio-wrap {
background: rgba(255,255,255,0.04);
border-radius: 10px;
padding: 14px 16px;
border: 1px solid rgba(255,255,255,0.08);
}
/* --- 高对比度修复 --- */
/* Radio / Checkbox 标签 — 亮蓝色 */
.radio-option label, .radio-option span,
label:has(input[type="radio"]), .radio-label,
fieldset label, .radio-wrap label {
color: #6aaff7 !important;
}
/* Radio hover — 亮灰色 */
.radio-option:hover label, .radio-option:hover span,
fieldset label:hover, .radio-wrap:hover label {
color: #c8d6e5 !important;
}
/* Radio 选中 — 加粗变白 */
input[type="radio"]:checked + label,
input[type="radio"]:checked ~ span {
color: #ffffff !important;
font-weight: 700 !important;
}
/* 输入框 */
input[type="text"], textarea, .input-box textarea {
background: #1a2236 !important;
border: 1px solid #3a5078 !important;
color: #e8edf5 !important;
border-radius: 10px !important;
padding: 12px 16px !important;
font-size: 14px !important;
}
input[type="text"]:focus, textarea:focus {
border-color: #4a8cf7 !important;
box-shadow: 0 0 0 3px rgba(74, 140, 247, 0.15) !important;
outline: none !important;
}
input[type="text"]::placeholder, textarea::placeholder {
color: #5a7099 !important;
}
/* select 下拉 */
select, .dropdown {
background: #1a2236 !important;
color: #d0d8e8 !important;
border: 1px solid #3a5078 !important;
}
/* 链接 */
a, .examples a {
color: #7aabf7 !important;
}
a:hover {
color: #a0c4ff !important;
}
/* 聊天气泡内容 */
.message-row .message {
color: #e4eaf5 !important;
}
/* 聊天气泡 */
.bubble-wrap { border-radius: 12px !important; }
/* 快捷图标 */
.quick-icon {
font-size: 16px;
margin-right: 6px;
}
/* 滚动条 */
::-webkit-scrollbar { width: 5px; }
::-webkit-scrollbar-track { background: transparent; }
::-webkit-scrollbar-thumb { background: rgba(255,255,255,0.1); border-radius: 3px; }
::-webkit-scrollbar-thumb:hover { background: rgba(255,255,255,0.2); }
"""
  330. # ===== 界面 =====
  331. with gr.Blocks(title="StockInsightAgent") as app:
  332. # ── Header ──
  333. gr.HTML("""
  334. <div class="header-wrap">
  335. <h1>StockInsightAgent</h1>
  336. <div class="subtitle">智能股票分析助手</div>
  337. <div class="status-row">
  338. <span class="status-badge online">&#9679; 系统就绪</span>
  339. <span class="status-badge data">&#9679; 数据源: 东方财富 / Sina / 腾讯</span>
  340. </div>
  341. </div>
  342. """)
  343. with gr.Row(equal_height=True):
  344. # ── 左侧栏 ──
  345. with gr.Column(scale=1, min_width=200):
  346. gr.HTML('<div class="sidebar-card"><h4>分析模式</h4></div>')
  347. mode_radio = gr.Radio(
  348. choices=["快速分析 (ReAct)", "深度分析 (PlanSolve)", "批判分析 (Reflection)"],
  349. value="快速分析 (ReAct)",
  350. label="",
  351. interactive=True,
  352. )
  353. gr.HTML('<div class="sidebar-card"><h4>快捷操作</h4></div>')
  354. btn_watchlist = gr.Button("关注列表", elem_classes="secondary")
  355. btn_history = gr.Button("分析历史", elem_classes="secondary")
  356. btn_kb = gr.Button("知识库状态", elem_classes="secondary")
  357. btn_prefs = gr.Button("用户偏好", elem_classes="secondary")
  358. gr.HTML("""
  359. <div style="margin-top:16px; font-size:11px; color:#5a7a9a; line-height:1.6;">
  360. 输入 <b>帮助</b> 查看更多命令<br>
  361. </div>
  362. """)
  363. # ── 主区域 ──
  364. with gr.Column(scale=4):
  365. agent_state = gr.State(None)
  366. chatbot = gr.Chatbot(
  367. label="",
  368. height=520,
  369. elem_classes="main-chat",
  370. placeholder="<div style='text-align:center; color:#6a8aaa; padding-top:80px;'>"
  371. "<div style='font-size:48px; margin-bottom:16px;'>📊</div>"
  372. "<div style='font-size:16px; font-weight:600;'>开始分析你的投资组合</div>"
  373. "<div style='font-size:13px; margin-top:8px;'>输入股票代码或名称,获取全方位分析报告</div>"
  374. "</div>",
  375. )
  376. with gr.Row(equal_height=True):
  377. msg_input = gr.Textbox(
  378. placeholder="输入分析问题...",
  379. label="",
  380. scale=6,
  381. elem_classes="input-box",
  382. )
  383. submit_btn = gr.Button("开始分析", variant="primary", elem_classes="primary", scale=1)
  384. # ── 事件绑定 ──
  385. msg_input.submit(
  386. fn=respond_stream,
  387. inputs=[msg_input, chatbot, mode_radio, agent_state],
  388. outputs=[chatbot, msg_input, agent_state],
  389. )
  390. submit_btn.click(
  391. fn=respond_stream,
  392. inputs=[msg_input, chatbot, mode_radio, agent_state],
  393. outputs=[chatbot, msg_input, agent_state],
  394. )
  395. def quick_action(action, history, agent):
  396. for result in respond_stream(action, history, "快速分析 (ReAct)", agent):
  397. pass
  398. return result[0], result[2]
  399. btn_watchlist.click(lambda h, a: quick_action("列表", h, a), [chatbot, agent_state], [chatbot, agent_state])
  400. btn_history.click(lambda h, a: quick_action("历史", h, a), [chatbot, agent_state], [chatbot, agent_state])
  401. btn_kb.click(lambda h, a: quick_action("知识库", h, a), [chatbot, agent_state], [chatbot, agent_state])
  402. btn_prefs.click(lambda h, a: quick_action("偏好", h, a), [chatbot, agent_state], [chatbot, agent_state])
  403. if __name__ == "__main__":
  404. app.launch(
  405. server_name="127.0.0.1", server_port=7861, share=False,
  406. css=CUSTOM_CSS,
  407. theme=gr.themes.Soft(primary_hue="blue", neutral_hue="slate"),
  408. )