|
@@ -21,17 +21,39 @@ def get_agent():
|
|
|
|
|
|
|
|
def _run_with_capture(q: queue.Queue, agent, mode: str, msg: str):
|
|
def _run_with_capture(q: queue.Queue, agent, mode: str, msg: str):
|
|
|
import io, sys
|
|
import io, sys
|
|
|
|
|
+ import contextlib
|
|
|
buffer = io.StringIO()
|
|
buffer = io.StringIO()
|
|
|
|
|
|
|
|
class QueueWriter:
|
|
class QueueWriter:
|
|
|
|
|
+ def __init__(self, original_stdout):
|
|
|
|
|
+ self.original = original_stdout
|
|
|
|
|
+ # Ensure _local exists for threads
|
|
|
|
|
+ import threading
|
|
|
|
|
+ self._local = threading.local()
|
|
|
|
|
+
|
|
|
|
|
+ @property
|
|
|
|
|
+ def is_active(self):
|
|
|
|
|
+ return getattr(self._local, "active", False)
|
|
|
|
|
+
|
|
|
|
|
+ @is_active.setter
|
|
|
|
|
+ def is_active(self, value):
|
|
|
|
|
+ self._local.active = value
|
|
|
|
|
+
|
|
|
def write(self, s):
|
|
def write(self, s):
|
|
|
- buffer.write(s)
|
|
|
|
|
- q.put(s)
|
|
|
|
|
|
|
+ if self.is_active:
|
|
|
|
|
+ buffer.write(s)
|
|
|
|
|
+ q.put(s)
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.original.write(s)
|
|
|
|
|
+
|
|
|
def flush(self):
|
|
def flush(self):
|
|
|
- pass
|
|
|
|
|
|
|
+ if not self.is_active:
|
|
|
|
|
+ self.original.flush()
|
|
|
|
|
+
|
|
|
|
|
+ if not isinstance(sys.stdout, QueueWriter):
|
|
|
|
|
+ sys.stdout = QueueWriter(sys.stdout)
|
|
|
|
|
|
|
|
- old = sys.stdout
|
|
|
|
|
- sys.stdout = QueueWriter()
|
|
|
|
|
|
|
+ sys.stdout.is_active = True
|
|
|
try:
|
|
try:
|
|
|
if mode == "深度分析 (PlanSolve)":
|
|
if mode == "深度分析 (PlanSolve)":
|
|
|
result = agent.plan_solve(msg)
|
|
result = agent.plan_solve(msg)
|
|
@@ -42,14 +64,17 @@ def _run_with_capture(q: queue.Queue, agent, mode: str, msg: str):
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
result = f"分析出错: {e}"
|
|
result = f"分析出错: {e}"
|
|
|
finally:
|
|
finally:
|
|
|
- sys.stdout = old
|
|
|
|
|
|
|
+ sys.stdout.is_active = False
|
|
|
q.put(None)
|
|
q.put(None)
|
|
|
q.result = result or ""
|
|
q.result = result or ""
|
|
|
|
|
|
|
|
|
|
|
|
|
-def respond_stream(message: str, history: list, mode: str):
|
|
|
|
|
|
|
+def respond_stream(message: str, history: list, mode: str, agent=None):
|
|
|
|
|
+ if agent is None:
|
|
|
|
|
+ agent = get_agent()
|
|
|
|
|
+
|
|
|
if not message or not message.strip():
|
|
if not message or not message.strip():
|
|
|
- yield history, ""
|
|
|
|
|
|
|
+ yield history, "", agent
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
msg = message.strip()
|
|
msg = message.strip()
|
|
@@ -67,7 +92,7 @@ def respond_stream(message: str, history: list, mode: str):
|
|
|
if msg in keys:
|
|
if msg in keys:
|
|
|
history.append({"role": "user", "content": msg})
|
|
history.append({"role": "user", "content": msg})
|
|
|
history.append({"role": "assistant", "content": handler()})
|
|
history.append({"role": "assistant", "content": handler()})
|
|
|
- yield history, ""
|
|
|
|
|
|
|
+ yield history, "", agent
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
if msg.startswith("关注 "):
|
|
if msg.startswith("关注 "):
|
|
@@ -75,25 +100,25 @@ def respond_stream(message: str, history: list, mode: str):
|
|
|
c, n = parts[0], (parts[1] if len(parts) > 1 else "")
|
|
c, n = parts[0], (parts[1] if len(parts) > 1 else "")
|
|
|
history.append({"role": "user", "content": msg})
|
|
history.append({"role": "user", "content": msg})
|
|
|
history.append({"role": "assistant", "content": memory_add_watchlist(f"{c}|{n}")})
|
|
history.append({"role": "assistant", "content": memory_add_watchlist(f"{c}|{n}")})
|
|
|
- yield history, ""
|
|
|
|
|
|
|
+ yield history, "", agent
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
if msg.startswith("移除 "):
|
|
if msg.startswith("移除 "):
|
|
|
history.append({"role": "user", "content": msg})
|
|
history.append({"role": "user", "content": msg})
|
|
|
history.append({"role": "assistant", "content": memory_remove_watchlist(msg[3:].strip())})
|
|
history.append({"role": "assistant", "content": memory_remove_watchlist(msg[3:].strip())})
|
|
|
- yield history, ""
|
|
|
|
|
|
|
+ yield history, "", agent
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
if msg.startswith("历史 "):
|
|
if msg.startswith("历史 "):
|
|
|
history.append({"role": "user", "content": msg})
|
|
history.append({"role": "user", "content": msg})
|
|
|
history.append({"role": "assistant", "content": memory_get_history(msg[3:].strip())})
|
|
history.append({"role": "assistant", "content": memory_get_history(msg[3:].strip())})
|
|
|
- yield history, ""
|
|
|
|
|
|
|
+ yield history, "", agent
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
if msg.startswith("导入 "):
|
|
if msg.startswith("导入 "):
|
|
|
history.append({"role": "user", "content": msg})
|
|
history.append({"role": "user", "content": msg})
|
|
|
history.append({"role": "assistant", "content": rag_import(msg[3:].strip()) + "\n" + rag_stats()})
|
|
history.append({"role": "assistant", "content": rag_import(msg[3:].strip()) + "\n" + rag_stats()})
|
|
|
- yield history, ""
|
|
|
|
|
|
|
+ yield history, "", agent
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
# ── 流式分析 ──
|
|
# ── 流式分析 ──
|
|
@@ -101,7 +126,7 @@ def respond_stream(message: str, history: list, mode: str):
|
|
|
history.append({"role": "assistant", "content": "..."})
|
|
history.append({"role": "assistant", "content": "..."})
|
|
|
|
|
|
|
|
q = queue.Queue()
|
|
q = queue.Queue()
|
|
|
- t = threading.Thread(target=_run_with_capture, args=(q, get_agent(), mode, msg), daemon=True)
|
|
|
|
|
|
|
+ t = threading.Thread(target=_run_with_capture, args=(q, agent, mode, msg), daemon=True)
|
|
|
t.start()
|
|
t.start()
|
|
|
|
|
|
|
|
collected = []
|
|
collected = []
|
|
@@ -111,17 +136,17 @@ def respond_stream(message: str, history: list, mode: str):
|
|
|
except queue.Empty:
|
|
except queue.Empty:
|
|
|
if collected:
|
|
if collected:
|
|
|
history[-1]["content"] = "".join(collected)
|
|
history[-1]["content"] = "".join(collected)
|
|
|
- yield history, ""
|
|
|
|
|
|
|
+ yield history, "", agent
|
|
|
continue
|
|
continue
|
|
|
if chunk is None:
|
|
if chunk is None:
|
|
|
break
|
|
break
|
|
|
collected.append(chunk)
|
|
collected.append(chunk)
|
|
|
history[-1]["content"] = "".join(collected)
|
|
history[-1]["content"] = "".join(collected)
|
|
|
- yield history, ""
|
|
|
|
|
|
|
+ yield history, "", agent
|
|
|
|
|
|
|
|
final = getattr(q, 'result', '') or "".join(collected)
|
|
final = getattr(q, 'result', '') or "".join(collected)
|
|
|
history[-1]["content"] = str(final)
|
|
history[-1]["content"] = str(final)
|
|
|
- yield history, ""
|
|
|
|
|
|
|
+ yield history, "", agent
|
|
|
|
|
|
|
|
|
|
|
|
|
HELP_TEXT = """## 使用指南
|
|
HELP_TEXT = """## 使用指南
|
|
@@ -390,6 +415,8 @@ with gr.Blocks(title="StockInsightAgent") as app:
|
|
|
|
|
|
|
|
# ── 主区域 ──
|
|
# ── 主区域 ──
|
|
|
with gr.Column(scale=4):
|
|
with gr.Column(scale=4):
|
|
|
|
|
+ agent_state = gr.State(None)
|
|
|
|
|
+
|
|
|
chatbot = gr.Chatbot(
|
|
chatbot = gr.Chatbot(
|
|
|
label="",
|
|
label="",
|
|
|
height=520,
|
|
height=520,
|
|
@@ -413,24 +440,24 @@ with gr.Blocks(title="StockInsightAgent") as app:
|
|
|
# ── 事件绑定 ──
|
|
# ── 事件绑定 ──
|
|
|
msg_input.submit(
|
|
msg_input.submit(
|
|
|
fn=respond_stream,
|
|
fn=respond_stream,
|
|
|
- inputs=[msg_input, chatbot, mode_radio],
|
|
|
|
|
- outputs=[chatbot, msg_input],
|
|
|
|
|
|
|
+ inputs=[msg_input, chatbot, mode_radio, agent_state],
|
|
|
|
|
+ outputs=[chatbot, msg_input, agent_state],
|
|
|
)
|
|
)
|
|
|
submit_btn.click(
|
|
submit_btn.click(
|
|
|
fn=respond_stream,
|
|
fn=respond_stream,
|
|
|
- inputs=[msg_input, chatbot, mode_radio],
|
|
|
|
|
- outputs=[chatbot, msg_input],
|
|
|
|
|
|
|
+ inputs=[msg_input, chatbot, mode_radio, agent_state],
|
|
|
|
|
+ outputs=[chatbot, msg_input, agent_state],
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- def quick_action(action, history):
|
|
|
|
|
- for result in respond_stream(action, history, "快速分析 (ReAct)"):
|
|
|
|
|
|
|
+ def quick_action(action, history, agent):
|
|
|
|
|
+ for result in respond_stream(action, history, "快速分析 (ReAct)", agent):
|
|
|
pass
|
|
pass
|
|
|
- return result[0]
|
|
|
|
|
|
|
+ return result[0], result[2]
|
|
|
|
|
|
|
|
- btn_watchlist.click(lambda h: quick_action("列表", h), [chatbot], [chatbot])
|
|
|
|
|
- btn_history.click(lambda h: quick_action("历史", h), [chatbot], [chatbot])
|
|
|
|
|
- btn_kb.click(lambda h: quick_action("知识库", h), [chatbot], [chatbot])
|
|
|
|
|
- btn_prefs.click(lambda h: quick_action("偏好", h), [chatbot], [chatbot])
|
|
|
|
|
|
|
+ btn_watchlist.click(lambda h, a: quick_action("列表", h, a), [chatbot, agent_state], [chatbot, agent_state])
|
|
|
|
|
+ btn_history.click(lambda h, a: quick_action("历史", h, a), [chatbot, agent_state], [chatbot, agent_state])
|
|
|
|
|
+ btn_kb.click(lambda h, a: quick_action("知识库", h, a), [chatbot, agent_state], [chatbot, agent_state])
|
|
|
|
|
+ btn_prefs.click(lambda h, a: quick_action("偏好", h, a), [chatbot, agent_state], [chatbot, agent_state])
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
app.launch(
|
|
app.launch(
|