Ver código fonte

fix:按照copilot的意见进行修改

CC1227871 1 mês atrás
pai
commit
fba72db21a

+ 55 - 28
Co-creation-projects/CC1227871-StockInsightAgent/app.py

@@ -21,17 +21,39 @@ def get_agent():
 
 def _run_with_capture(q: queue.Queue, agent, mode: str, msg: str):
     import io, sys
+    import contextlib
     buffer = io.StringIO()
 
     class QueueWriter:
+        def __init__(self, original_stdout):
+            self.original = original_stdout
+            # Ensure _local exists for threads
+            import threading
+            self._local = threading.local()
+
+        @property
+        def is_active(self):
+            return getattr(self._local, "active", False)
+
+        @is_active.setter
+        def is_active(self, value):
+            self._local.active = value
+
         def write(self, s):
-            buffer.write(s)
-            q.put(s)
+            if self.is_active:
+                buffer.write(s)
+                q.put(s)
+            else:
+                self.original.write(s)
+
         def flush(self):
-            pass
+            if not self.is_active:
+                self.original.flush()
+
+    if not isinstance(sys.stdout, QueueWriter):
+        sys.stdout = QueueWriter(sys.stdout)
 
-    old = sys.stdout
-    sys.stdout = QueueWriter()
+    sys.stdout.is_active = True
     try:
         if mode == "深度分析 (PlanSolve)":
             result = agent.plan_solve(msg)
@@ -42,14 +64,17 @@ def _run_with_capture(q: queue.Queue, agent, mode: str, msg: str):
     except Exception as e:
         result = f"分析出错: {e}"
     finally:
-        sys.stdout = old
+        sys.stdout.is_active = False
     q.put(None)
     q.result = result or ""
 
 
-def respond_stream(message: str, history: list, mode: str):
+def respond_stream(message: str, history: list, mode: str, agent=None):
+    if agent is None:
+        agent = get_agent()
+
     if not message or not message.strip():
-        yield history, ""
+        yield history, "", agent
         return
 
     msg = message.strip()
@@ -67,7 +92,7 @@ def respond_stream(message: str, history: list, mode: str):
         if msg in keys:
             history.append({"role": "user", "content": msg})
             history.append({"role": "assistant", "content": handler()})
-            yield history, ""
+            yield history, "", agent
             return
 
     if msg.startswith("关注 "):
@@ -75,25 +100,25 @@ def respond_stream(message: str, history: list, mode: str):
         c, n = parts[0], (parts[1] if len(parts) > 1 else "")
         history.append({"role": "user", "content": msg})
         history.append({"role": "assistant", "content": memory_add_watchlist(f"{c}|{n}")})
-        yield history, ""
+        yield history, "", agent
         return
 
     if msg.startswith("移除 "):
         history.append({"role": "user", "content": msg})
         history.append({"role": "assistant", "content": memory_remove_watchlist(msg[3:].strip())})
-        yield history, ""
+        yield history, "", agent
         return
 
     if msg.startswith("历史 "):
         history.append({"role": "user", "content": msg})
         history.append({"role": "assistant", "content": memory_get_history(msg[3:].strip())})
-        yield history, ""
+        yield history, "", agent
         return
 
     if msg.startswith("导入 "):
         history.append({"role": "user", "content": msg})
         history.append({"role": "assistant", "content": rag_import(msg[3:].strip()) + "\n" + rag_stats()})
-        yield history, ""
+        yield history, "", agent
         return
 
     # ── 流式分析 ──
@@ -101,7 +126,7 @@ def respond_stream(message: str, history: list, mode: str):
     history.append({"role": "assistant", "content": "..."})
 
     q = queue.Queue()
-    t = threading.Thread(target=_run_with_capture, args=(q, get_agent(), mode, msg), daemon=True)
+    t = threading.Thread(target=_run_with_capture, args=(q, agent, mode, msg), daemon=True)
     t.start()
 
     collected = []
@@ -111,17 +136,17 @@ def respond_stream(message: str, history: list, mode: str):
         except queue.Empty:
             if collected:
                 history[-1]["content"] = "".join(collected)
-            yield history, ""
+            yield history, "", agent
             continue
         if chunk is None:
             break
         collected.append(chunk)
         history[-1]["content"] = "".join(collected)
-        yield history, ""
+        yield history, "", agent
 
     final = getattr(q, 'result', '') or "".join(collected)
     history[-1]["content"] = str(final)
-    yield history, ""
+    yield history, "", agent
 
 
 HELP_TEXT = """## 使用指南
@@ -390,6 +415,8 @@ with gr.Blocks(title="StockInsightAgent") as app:
 
         # ── 主区域 ──
         with gr.Column(scale=4):
+            agent_state = gr.State(None)
+
             chatbot = gr.Chatbot(
                 label="",
                 height=520,
@@ -413,24 +440,24 @@ with gr.Blocks(title="StockInsightAgent") as app:
     # ── 事件绑定 ──
     msg_input.submit(
         fn=respond_stream,
-        inputs=[msg_input, chatbot, mode_radio],
-        outputs=[chatbot, msg_input],
+        inputs=[msg_input, chatbot, mode_radio, agent_state],
+        outputs=[chatbot, msg_input, agent_state],
     )
     submit_btn.click(
         fn=respond_stream,
-        inputs=[msg_input, chatbot, mode_radio],
-        outputs=[chatbot, msg_input],
+        inputs=[msg_input, chatbot, mode_radio, agent_state],
+        outputs=[chatbot, msg_input, agent_state],
     )
 
-    def quick_action(action, history):
-        for result in respond_stream(action, history, "快速分析 (ReAct)"):
+    def quick_action(action, history, agent):
+        for result in respond_stream(action, history, "快速分析 (ReAct)", agent):
             pass
-        return result[0]
+        return result[0], result[2]
 
-    btn_watchlist.click(lambda h: quick_action("列表", h), [chatbot], [chatbot])
-    btn_history.click(lambda h: quick_action("历史", h), [chatbot], [chatbot])
-    btn_kb.click(lambda h: quick_action("知识库", h), [chatbot], [chatbot])
-    btn_prefs.click(lambda h: quick_action("偏好", h), [chatbot], [chatbot])
+    btn_watchlist.click(lambda h, a: quick_action("列表", h, a), [chatbot, agent_state], [chatbot, agent_state])
+    btn_history.click(lambda h, a: quick_action("历史", h, a), [chatbot, agent_state], [chatbot, agent_state])
+    btn_kb.click(lambda h, a: quick_action("知识库", h, a), [chatbot, agent_state], [chatbot, agent_state])
+    btn_prefs.click(lambda h, a: quick_action("偏好", h, a), [chatbot, agent_state], [chatbot, agent_state])
 
 if __name__ == "__main__":
     app.launch(

+ 8 - 2
Co-creation-projects/CC1227871-StockInsightAgent/context_manager.py

@@ -63,8 +63,14 @@ class ContextManager:
             self.summary = new_summary
 
         # 限制摘要长度
-        if self._estimate_tokens(self.summary) > 1500:
-            self.summary = self.summary[-1500:]
+        while self._estimate_tokens(self.summary) > 1500:
+            # Drop the earliest part of the summary string by splitting on lines
+            lines = self.summary.split('\n')
+            if len(lines) <= 2:
+                # If there are only a couple lines left, we must chop characters
+                self.summary = self.summary[int(len(self.summary) * 0.8):]
+            else:
+                self.summary = "对话历史摘要:\n" + "\n".join(lines[2:])
 
         self.turns = recent
 

+ 16 - 3
Co-creation-projects/CC1227871-StockInsightAgent/memory.py

@@ -24,9 +24,22 @@ class StockMemory:
         return {"watchlist": {}, "history": [], "preferences": {}}
 
     def _save(self):
-        os.makedirs(os.path.dirname(self.path), exist_ok=True)
-        with open(self.path, "w", encoding="utf-8") as f:
-            json.dump(self.data, f, ensure_ascii=False, indent=2)
+        import tempfile
+        import threading
+
+        if not hasattr(self, '_lock'):
+            self._lock = threading.Lock()
+
+        with self._lock:
+            os.makedirs(os.path.dirname(self.path), exist_ok=True)
+            fd, temp_path = tempfile.mkstemp(dir=os.path.dirname(self.path))
+            try:
+                with os.fdopen(fd, "w", encoding="utf-8") as f:
+                    json.dump(self.data, f, ensure_ascii=False, indent=2)
+                os.replace(temp_path, self.path)
+            except Exception:
+                os.remove(temp_path)
+                raise
 
     # ===== 关注列表 =====
     def add_watchlist(self, code: str, name: str = "", notes: str = "") -> str:

+ 2 - 2
Co-creation-projects/CC1227871-StockInsightAgent/rag.py

@@ -124,8 +124,8 @@ class InvestmentKnowledgeBase:
 
         if not scores:
             # 回退到关键词匹配
-            for i, doc in enumerate(all_docs):
-                if any(kw in doc for kw in query):
+            for i, (doc, doc_tokens) in enumerate(zip(all_docs, tokenized_docs)):
+                if any(kw in doc_tokens for kw in tokenized_query):
                     scores.append((0.5, i))
             scores.sort(key=lambda x: x[0], reverse=True)
 

+ 22 - 21
Co-creation-projects/CC1227871-StockInsightAgent/test_tools.py

@@ -1,29 +1,30 @@
 """快速验证:单独测试每个工具函数"""
 from tools import get_realtime_quote, get_historical_data, get_financial_data, calc_indicators, get_news
 
-print("=" * 60)
-print("测试 1: 实时行情")
-print("=" * 60)
-print(get_realtime_quote("600519"))
+if __name__ == "__main__":
+    print("=" * 60)
+    print("测试 1: 实时行情")
+    print("=" * 60)
+    print(get_realtime_quote("600519"))
 
-print("\n" + "=" * 60)
-print("测试 2: 历史K线")
-print("=" * 60)
-print(get_historical_data("600519|daily|10"))
+    print("\n" + "=" * 60)
+    print("测试 2: 历史K线")
+    print("=" * 60)
+    print(get_historical_data("600519|daily|10"))
 
-print("\n" + "=" * 60)
-print("测试 3: 技术指标")
-print("=" * 60)
-print(calc_indicators("600519|daily|60"))
+    print("\n" + "=" * 60)
+    print("测试 3: 技术指标")
+    print("=" * 60)
+    print(calc_indicators("600519|daily|60"))
 
-print("\n" + "=" * 60)
-print("测试 4: 财务数据")
-print("=" * 60)
-print(get_financial_data("600519"))
+    print("\n" + "=" * 60)
+    print("测试 4: 财务数据")
+    print("=" * 60)
+    print(get_financial_data("600519"))
 
-print("\n" + "=" * 60)
-print("测试 5: 新闻")
-print("=" * 60)
-print(get_news("600519"))
+    print("\n" + "=" * 60)
+    print("测试 5: 新闻")
+    print("=" * 60)
+    print(get_news("600519"))
 
-print("\n全部工具测试完成!")
+    print("\n全部工具测试完成!")

+ 70 - 12
Co-creation-projects/CC1227871-StockInsightAgent/tools.py

@@ -43,6 +43,24 @@ def _resolve_symbol(query: str) -> str:
     query = query.strip()
     if query.isdigit() and len(query) == 6:
         return query
+
+    # 尝试使用 akshare stock_info_a_code_name 映射
+    try:
+        import akshare as ak
+        stock_info = ak.stock_info_a_code_name()
+
+        # 匹配名称
+        match = stock_info[stock_info["name"] == query]
+        if not match.empty:
+            return match["code"].values[0]
+
+        # 模糊匹配名称
+        fuzzy_match = stock_info[stock_info["name"].str.contains(query, na=False)]
+        if not fuzzy_match.empty:
+            return fuzzy_match["code"].values[0]
+    except Exception:
+        pass
+
     # 尝试通过新闻接口反查(间接方式)
     try:
         time.sleep(1)
@@ -76,7 +94,7 @@ def get_realtime_quote(query: str) -> str:
     数据源: 东方财富个股信息 + Sina 日线最新一条。
     """
     print(f"  [查询实时行情] {query}")
-    symbol = query.strip()
+    symbol = _resolve_symbol(query)
 
     # 使用 Sina 日线获取最新价格
     try:
@@ -86,8 +104,19 @@ def get_realtime_quote(query: str) -> str:
                          start_date=(datetime.now() - timedelta(days=10)).strftime("%Y%m%d"),
                          end_date=datetime.now().strftime("%Y%m%d"),
                          adjust="qfq")
+        if isinstance(df, str) or df is None or df.empty:
+            df = _safe_fetch(ak.stock_zh_a_hist, symbol=symbol, period="daily", start_date=(datetime.now() - timedelta(days=10)).strftime("%Y%m%d"), end_date=datetime.now().strftime("%Y%m%d"), adjust="qfq")
+
         if isinstance(df, str) or df is None or df.empty:
             return f"未找到 {symbol} 的行情数据"
+
+        if df is not None and not df.empty:
+            # 统一列名为英文以适配下游逻辑
+            rename_map = {
+                "日期": "date", "开盘": "open", "收盘": "close",
+                "最高": "high", "最低": "low", "成交量": "volume", "成交额": "amount"
+            }
+            df = df.rename(columns=rename_map)
     except Exception as e:
         return f"获取行情失败: {e}"
 
@@ -123,14 +152,14 @@ def get_realtime_quote(query: str) -> str:
 def get_historical_data(query: str) -> str:
     """
     获取历史K线数据。输入格式: "symbol|period|days"
-    period: daily(日), days: 最近多少天(默认60)
+    period: daily/weekly/monthly(日/周/月), days: 最近多少个周期(默认60)
     示例: "600519|daily|30"
     数据源: Sina
     """
     print(f"  [查询历史数据] {query}")
 
     parts = query.strip().split("|")
-    symbol = parts[0].strip()
+    symbol = _resolve_symbol(parts[0].strip())
     period = parts[1].strip() if len(parts) > 1 else "daily"
     try:
         days = int(parts[2]) if len(parts) > 2 else 60
@@ -138,11 +167,17 @@ def get_historical_data(query: str) -> str:
         days = 60
 
     end = datetime.now().strftime("%Y%m%d")
-    start = (datetime.now() - timedelta(days=days * 2)).strftime("%Y%m%d")
+    start = (datetime.now() - timedelta(days=days * 30)).strftime("%Y%m%d") if period != "daily" else (datetime.now() - timedelta(days=days * 2)).strftime("%Y%m%d")
 
     try:
         sina_code = _to_sina_code(symbol)
-        hist = _safe_fetch(ak.stock_zh_a_daily,
+        period_map = {"daily": "daily", "weekly": "weekly", "monthly": "monthly"}
+        ak_period = period_map.get(period, "daily")
+        hist = _safe_fetch(ak.stock_zh_a_hist,
+                           symbol=symbol, period=ak_period, start_date=start,
+                           end_date=end, adjust="qfq")
+        if isinstance(hist, str) or hist is None or hist.empty:
+            hist = _safe_fetch(ak.stock_zh_a_daily,
                            symbol=sina_code, start_date=start,
                            end_date=end, adjust="qfq")
         if isinstance(hist, str):
@@ -157,6 +192,13 @@ def get_historical_data(query: str) -> str:
                 "date": "date", "open": "open", "close": "close",
                 "high": "high", "low": "low", "amount": "volume"
             })
+        elif hist is not None and not hist.empty:
+            # 统一列名为英文以适配下游逻辑
+            rename_map = {
+                "日期": "date", "开盘": "open", "收盘": "close",
+                "最高": "high", "最低": "low", "成交量": "volume", "成交额": "amount"
+            }
+            hist = hist.rename(columns=rename_map)
     except Exception as e:
         return f"获取历史数据失败: {e}"
 
@@ -237,16 +279,16 @@ def get_financial_data(symbol: str) -> str:
         try:
             if unit == "元" and abs(float(val)) > 1e8:
                 val_str = f"{float(val)/1e8:.2f}亿"
-                if prev_val and not pd.isna(prev_val) and abs(float(prev_val)) > 1e8:
+                if prev_val is not None and not pd.isna(prev_val) and abs(float(prev_val)) > 1e8:
                     prev_str = f"{float(prev_val)/1e8:.2f}亿"
                 else:
                     prev_str = None
             elif unit == "%":
                 val_str = f"{float(val):.2f}%"
-                prev_str = f"{float(prev_val):.2f}%" if prev_val and not pd.isna(prev_val) else None
+                prev_str = f"{float(prev_val):.2f}%" if prev_val is not None and not pd.isna(prev_val) else None
             else:
                 val_str = f"{float(val):.4f}"
-                prev_str = f"{float(prev_val):.4f}" if prev_val and not pd.isna(prev_val) else None
+                prev_str = f"{float(prev_val):.4f}" if prev_val is not None and not pd.isna(prev_val) else None
         except (ValueError, TypeError):
             val_str = str(val)
             prev_str = str(prev_val) if prev_val is not None else None
@@ -286,8 +328,20 @@ def calc_indicators(query: str) -> str:
         df = _safe_fetch(ak.stock_zh_a_daily,
                          symbol=sina_code, start_date=start,
                          end_date=end, adjust="qfq")
+        if isinstance(df, str) or df is None or df.empty:
+            # 尝试新版 API fallback
+            df = _safe_fetch(ak.stock_zh_a_hist, symbol=symbol, period="daily", start_date=start, end_date=end, adjust="qfq")
+
         if isinstance(df, str) or df is None or df.empty:
             return f"未找到 {symbol} 的数据"
+
+        if df is not None and not df.empty:
+            # 统一列名为英文以适配下游逻辑
+            rename_map = {
+                "日期": "date", "开盘": "open", "收盘": "close",
+                "最高": "high", "最低": "low", "成交量": "volume", "成交额": "amount"
+            }
+            df = df.rename(columns=rename_map)
     except Exception as e:
         return f"获取数据失败: {e}"
 
@@ -328,10 +382,14 @@ def calc_indicators(query: str) -> str:
     bar_color = "红柱" if macd_bar.iloc[-1] > 0 else "绿柱"
     lines.append(f"  MACD柱: {macd_bar.iloc[-1]:.3f}  ({bar_color})")
 
-    if dif.iloc[-1] > dea.iloc[-1] and dif.iloc[-2] <= dea.iloc[-2]:
-        lines.append("  [!] 信号: 金叉(买入信号)")
-    elif dif.iloc[-1] < dea.iloc[-1] and dif.iloc[-2] >= dea.iloc[-2]:
-        lines.append("  [!] 信号: 死叉(卖出信号)")
+    if len(dif) >= 2:
+        if dif.iloc[-1] > dea.iloc[-1] and dif.iloc[-2] <= dea.iloc[-2]:
+            lines.append("  [!] 信号: 金叉(买入信号)")
+        elif dif.iloc[-1] < dea.iloc[-1] and dif.iloc[-2] >= dea.iloc[-2]:
+            lines.append("  [!] 信号: 死叉(卖出信号)")
+        else:
+            trend = "多头" if dif.iloc[-1] > dea.iloc[-1] else "空头"
+            lines.append(f"  趋势: {trend}持续")
     else:
         trend = "多头" if dif.iloc[-1] > dea.iloc[-1] else "空头"
         lines.append(f"  趋势: {trend}持续")