preference_service.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. """
  2. 智能股票分析助手 — 用户偏好服务层
  3. 提供偏好的CRUD操作,以及向智能体层输出偏好上下文的方法。
  4. """
  5. import json
  6. from typing import Optional
  7. from sqlalchemy import select
  8. from sqlalchemy.ext.asyncio import AsyncSession
  9. from app.models.preference import UserPreference
  10. # =========================================================================
  11. # 默认偏好(当用户未配置或无登录时使用)
  12. # =========================================================================
  13. DEFAULT_PREFERENCE = {
  14. "risk_tolerance": "moderate",
  15. "investment_style": "blend",
  16. "preferred_sectors": [],
  17. "excluded_sectors": [],
  18. "investment_horizon": "medium",
  19. "target_return_rate": 10.0,
  20. "max_position_ratio": 30.0,
  21. "max_drawdown_limit": -15.0,
  22. "notification_enabled": True,
  23. "notification_channels": ["push"],
  24. "market_alert_threshold": 5.0,
  25. "language": "zh",
  26. "theme": "auto",
  27. "default_view": "dashboard",
  28. }
  29. # =========================================================================
  30. # CRUD 操作
  31. # =========================================================================
  32. async def get_preference(db: AsyncSession, user_id: str = "default") -> dict:
  33. """获取用户偏好,不存在时返回默认值"""
  34. result = await db.execute(
  35. select(UserPreference).where(UserPreference.user_id == user_id)
  36. )
  37. pref = result.scalar_one_or_none()
  38. if pref is None:
  39. return {**DEFAULT_PREFERENCE, "user_id": user_id}
  40. return pref.to_dict()
  41. async def get_or_create_preference(db: AsyncSession, user_id: str = "default") -> UserPreference:
  42. """获取或创建用户偏好记录,返回ORM对象"""
  43. result = await db.execute(
  44. select(UserPreference).where(UserPreference.user_id == user_id)
  45. )
  46. pref = result.scalar_one_or_none()
  47. if pref is None:
  48. pref = UserPreference.create_default(user_id)
  49. db.add(pref)
  50. await db.commit()
  51. await db.refresh(pref)
  52. return pref
  53. async def update_preference(db: AsyncSession, user_id: str, data: dict) -> dict:
  54. """更新用户偏好,支持部分更新"""
  55. pref = await get_or_create_preference(db, user_id)
  56. # 允许更新的字段白名单(防止注入未定义的字段)
  57. ALLOWED_FIELDS = {
  58. "risk_tolerance", "investment_style", "investment_horizon",
  59. "target_return_rate", "max_position_ratio", "max_drawdown_limit",
  60. "notification_enabled", "market_alert_threshold",
  61. "language", "theme", "default_view",
  62. # 以下是JSON字段
  63. "preferred_sectors", "excluded_sectors", "notification_channels",
  64. }
  65. for key, value in data.items():
  66. if key not in ALLOWED_FIELDS:
  67. continue
  68. # JSON数组字段序列化
  69. if key in ("preferred_sectors", "excluded_sectors", "notification_channels"):
  70. if value is not None:
  71. setattr(pref, key, json.dumps(value, ensure_ascii=False))
  72. # 布尔值字段
  73. elif key == "notification_enabled":
  74. setattr(pref, key, bool(value))
  75. # 数值字段
  76. elif key in ("target_return_rate", "max_position_ratio", "max_drawdown_limit", "market_alert_threshold"):
  77. setattr(pref, key, float(value))
  78. else:
  79. setattr(pref, key, value)
  80. await db.commit()
  81. await db.refresh(pref)
  82. return pref.to_dict()
  83. # =========================================================================
  84. # 智能体注入方法
  85. # =========================================================================
  86. async def get_preference_context(db: AsyncSession, user_id: str = "default") -> str:
  87. """生成偏好上下文文本,供智能体层注入分析流程
  88. 返回格式化的中文描述,可直接作为Agent系统提示词的一部分。
  89. """
  90. pref = await get_preference(db, user_id)
  91. if pref is None:
  92. pref = DEFAULT_PREFERENCE
  93. risk_map = {
  94. "conservative": "保守型——侧重低估值、高股息、蓝筹股,回避高风险标的",
  95. "moderate": "稳健型——均衡配置,兼顾成长与价值",
  96. "aggressive": "激进型——侧重高成长、高波动标的,接受较大回撤",
  97. }
  98. style_map = {
  99. "value": "价值投资——偏好低PE、低PB、高股息率标的",
  100. "growth": "成长投资——偏好高营收增速、高利润增速标的",
  101. "momentum": "动量投资——偏好强势股、趋势跟踪",
  102. "dividend": "股息投资——偏好高分红率标的",
  103. "blend": "混合风格——综合运用多种投资策略",
  104. }
  105. horizon_map = {
  106. "short": "短期(<1年)",
  107. "medium": "中期(1-3年)",
  108. "long": "长期(>3年)",
  109. }
  110. preferred = json.loads(pref.get("preferred_sectors", "[]")) if isinstance(pref.get("preferred_sectors"), str) else pref.get("preferred_sectors", [])
  111. excluded = json.loads(pref.get("excluded_sectors", "[]")) if isinstance(pref.get("excluded_sectors"), str) else pref.get("excluded_sectors", [])
  112. context_parts = [
  113. "## 用户投资偏好(请据此调整分析和建议)",
  114. f"- 风险承受度: {risk_map.get(pref['risk_tolerance'], pref['risk_tolerance'])}",
  115. f"- 投资风格: {style_map.get(pref['investment_style'], pref['investment_style'])}",
  116. f"- 投资期限: {horizon_map.get(pref['investment_horizon'], pref['investment_horizon'])}",
  117. f"- 目标年化收益率: {pref['target_return_rate']}%",
  118. f"- 单票最大仓位: {pref['max_position_ratio']}%",
  119. f"- 最大回撤预警线: {pref['max_drawdown_limit']}%",
  120. ]
  121. if preferred:
  122. context_parts.append(f"- 偏好行业: {', '.join(preferred)}")
  123. if excluded:
  124. context_parts.append(f"- 排除行业: {', '.join(excluded)}")
  125. return "\n".join(context_parts)
  126. async def get_profile_summary(db: AsyncSession, user_id: str = "default") -> dict:
  127. """获取用户投资画像摘要(用于前端偏好摘要展示)"""
  128. pref = await get_preference(db, user_id)
  129. if pref is None:
  130. pref = DEFAULT_PREFERENCE
  131. risk_labels = {"conservative": "保守型", "moderate": "稳健型", "aggressive": "激进型"}
  132. style_labels = {"value": "价值投资", "growth": "成长投资", "momentum": "动量投资", "dividend": "股息投资", "blend": "混合风格"}
  133. return {
  134. "user_id": pref["user_id"],
  135. "risk_label": risk_labels.get(pref["risk_tolerance"], pref["risk_tolerance"]),
  136. "style_label": style_labels.get(pref["investment_style"], pref["investment_style"]),
  137. "target_return": f"{pref['target_return_rate']}%",
  138. "max_drawdown": f"{pref['max_drawdown_limit']}%",
  139. "preferred_sectors_count": len(pref.get("preferred_sectors", []) if isinstance(pref.get("preferred_sectors"), list) else json.loads(pref.get("preferred_sectors", "[]"))),
  140. "is_configured": pref.get("id") is not None,
  141. }