Răsfoiți Sursa

fix:修复若干问题,增强agent整体稳定性

CC1227871 1 lună în urmă
părinte
comite
b62481ba00

+ 3 - 2
Co-creation-projects/CC1227871-StockInsightAgent/context_manager.py

@@ -67,8 +67,9 @@ class ContextManager:
             # Drop the earliest part of the summary string by splitting on lines
             lines = self.summary.split('\n')
             if len(lines) <= 2:
-                # If there are only a couple lines left, we must chop characters
-                self.summary = self.summary[int(len(self.summary) * 0.8):]
+                # If there are only a couple lines left, we must chop strings carefully or discard
+                self.summary = ""
+                break
             else:
                 self.summary = "对话历史摘要:\n" + "\n".join(lines[2:])
 

+ 32 - 29
Co-creation-projects/CC1227871-StockInsightAgent/framework_agent.py

@@ -115,35 +115,38 @@ class FrameworkStockAgent:
         base = HelloAgentsLLM()
         reasoning_entries = []  # 按序存储每轮 assistant 的 reasoning_content
 
-        adapter = base._adapter
-        adapter._client = adapter.create_client()
-        original_create = adapter._client.chat.completions.create
-
-        def patched_create(*args, **kwargs):
-            messages = kwargs.get("messages", [])
-            # 注入 reasoning_content:给最近一条没有它的 assistant 消息
-            missing_idx = 0
-            fixed_msgs = []
-            for m in messages:
-                m2 = dict(m)
-                if m2.get("role") == "assistant" and not m2.get("reasoning_content"):
-                    if missing_idx < len(reasoning_entries):
-                        m2["reasoning_content"] = reasoning_entries[missing_idx]
-                    missing_idx += 1
-                fixed_msgs.append(m2)
-            kwargs["messages"] = fixed_msgs
-            resp = original_create(*args, **kwargs)
-            # 保存新 reasoning_content
-            try:
-                msg = resp.choices[0].message
-                rc = getattr(msg, "reasoning_content", None)
-                if rc:
-                    reasoning_entries.append(rc)
-            except Exception:
-                pass
-            return resp
-
-        adapter._client.chat.completions.create = patched_create
+        try:
+            adapter = base._adapter
+            adapter._client = adapter.create_client()
+            original_create = adapter._client.chat.completions.create
+
+            def patched_create(*args, **kwargs):
+                messages = kwargs.get("messages", [])
+                # 注入 reasoning_content:给最近一条没有它的 assistant 消息
+                missing_idx = 0
+                fixed_msgs = []
+                for m in messages:
+                    m2 = dict(m)
+                    if m2.get("role") == "assistant" and not m2.get("reasoning_content"):
+                        if missing_idx < len(reasoning_entries):
+                            m2["reasoning_content"] = reasoning_entries[missing_idx]
+                        missing_idx += 1
+                    fixed_msgs.append(m2)
+                kwargs["messages"] = fixed_msgs
+                resp = original_create(*args, **kwargs)
+                # 保存新 reasoning_content
+                try:
+                    msg = resp.choices[0].message
+                    rc = getattr(msg, "reasoning_content", None)
+                    if rc:
+                        reasoning_entries.append(rc)
+                except Exception:
+                    pass
+                return resp
+
+            adapter._client.chat.completions.create = patched_create
+        except Exception as e:
+            print(f"Warning: _build_llm monkey patch failed, falling back to standard LLM: {e}")
         return base
 
     def _run_with_context(self, agent, question: str, mode: str):

+ 1 - 1
Co-creation-projects/CC1227871-StockInsightAgent/llm_client.py

@@ -54,4 +54,4 @@ class HelloAgentsLLM:
                 return clean
             except Exception as e2:
                 print(f"[ERR] 非流式也失败: {e2}")
-                return None
+                raise RuntimeError(f"LLM调用完全失败: {e2}")

+ 2 - 4
Co-creation-projects/CC1227871-StockInsightAgent/memory.py

@@ -11,8 +11,10 @@ class StockMemory:
     """股票分析记忆 — JSON 文件持久化"""
 
     def __init__(self, path: str = "memory/stock_memory.json"):
+        import threading
         self.path = path
         self.data = self._load()
+        self._lock = threading.Lock()
 
     def _load(self) -> dict:
         if os.path.exists(self.path):
@@ -25,10 +27,6 @@ class StockMemory:
 
     def _save(self):
         import tempfile
-        import threading
-
-        if not hasattr(self, '_lock'):
-            self._lock = threading.Lock()
 
         with self._lock:
             os.makedirs(os.path.dirname(self.path), exist_ok=True)

+ 6 - 2
Co-creation-projects/CC1227871-StockInsightAgent/rag.py

@@ -41,8 +41,12 @@ class InvestmentKnowledgeBase:
         if not os.path.exists(filepath):
             return f"文件不存在: {filepath}"
         try:
-            with open(filepath, "r", encoding="utf-8") as f:
-                text = f.read()
+            try:
+                with open(filepath, "r", encoding="utf-8") as f:
+                    text = f.read()
+            except UnicodeDecodeError:
+                with open(filepath, "r", encoding="gbk") as f:
+                    text = f.read()
         except Exception as e:
             return f"读取文件失败: {e}"
         title = os.path.basename(filepath)

+ 10 - 11
Co-creation-projects/CC1227871-StockInsightAgent/tools.py

@@ -77,13 +77,12 @@ def _safe_fetch(func, *args, **kwargs):
     import random
     for attempt in range(3):
         try:
-            time.sleep(2 + random.random())
             return func(*args, **kwargs)
         except Exception as e:
             if attempt < 2:
                 time.sleep(4 + random.random() * 2)
             else:
-                return f"数据获取失败: {e}"
+                return None
 
 
 # ==================== 工具函数 ====================
@@ -104,10 +103,10 @@ def get_realtime_quote(query: str) -> str:
                          start_date=(datetime.now() - timedelta(days=10)).strftime("%Y%m%d"),
                          end_date=datetime.now().strftime("%Y%m%d"),
                          adjust="qfq")
-        if isinstance(df, str) or df is None or df.empty:
+        if df is None or df.empty:
             df = _safe_fetch(ak.stock_zh_a_hist, symbol=symbol, period="daily", start_date=(datetime.now() - timedelta(days=10)).strftime("%Y%m%d"), end_date=datetime.now().strftime("%Y%m%d"), adjust="qfq")
 
-        if isinstance(df, str) or df is None or df.empty:
+        if df is None or df.empty:
             return f"未找到 {symbol} 的行情数据"
 
         if df is not None and not df.empty:
@@ -176,16 +175,16 @@ def get_historical_data(query: str) -> str:
         hist = _safe_fetch(ak.stock_zh_a_hist,
                            symbol=symbol, period=ak_period, start_date=start,
                            end_date=end, adjust="qfq")
-        if isinstance(hist, str) or hist is None or hist.empty:
+        if hist is None or hist.empty:
             hist = _safe_fetch(ak.stock_zh_a_daily,
                            symbol=sina_code, start_date=start,
                            end_date=end, adjust="qfq")
-        if isinstance(hist, str):
+        if hist is None or hist.empty:
             # 尝试 Tencent 源
             time.sleep(2)
             hist = ak.stock_zh_a_hist_tx(symbol=sina_code,
                                          start_date=start, end_date=end)
-            if isinstance(hist, str) or hist is None or hist.empty:
+            if hist is None or hist.empty:
                 return f"未找到 {symbol} 的历史数据"
             # Tencent 列名映射
             hist = hist.rename(columns={
@@ -233,7 +232,7 @@ def get_financial_data(symbol: str) -> str:
 
     try:
         df = _safe_fetch(ak.stock_financial_abstract, symbol=symbol)
-        if isinstance(df, str) or df is None or df.empty:
+        if df is None or df.empty:
             return f"未找到 {symbol} 的财务数据"
     except Exception as e:
         return f"获取财务数据失败: {e}"
@@ -328,11 +327,11 @@ def calc_indicators(query: str) -> str:
         df = _safe_fetch(ak.stock_zh_a_daily,
                          symbol=sina_code, start_date=start,
                          end_date=end, adjust="qfq")
-        if isinstance(df, str) or df is None or df.empty:
+        if df is None or df.empty:
             # 尝试新版 API fallback
             df = _safe_fetch(ak.stock_zh_a_hist, symbol=symbol, period="daily", start_date=start, end_date=end, adjust="qfq")
 
-        if isinstance(df, str) or df is None or df.empty:
+        if df is None or df.empty:
             return f"未找到 {symbol} 的数据"
 
         if df is not None and not df.empty:
@@ -460,7 +459,7 @@ def get_news(symbol: str) -> str:
 
     try:
         news_df = _safe_fetch(ak.stock_news_em, symbol=symbol)
-        if isinstance(news_df, str) or news_df is None or news_df.empty:
+        if news_df is None or news_df.empty:
             return f"未找到 {symbol} 的相关新闻"
     except Exception as e:
         return f"获取新闻失败: {e}"