database.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. """
  2. 智能股票分析助手 — 数据库连接模块
  3. 使用SQLAlchemy + aiosqlite实现异步数据库访问。
  4. 数据库文件自动创建在项目data目录下。
  5. """
  6. from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
  7. from sqlalchemy.orm import DeclarativeBase
  8. from pathlib import Path
  9. import sys
  10. # 确保能导入配置模块
  11. _PROJECT_ROOT = Path(__file__).parent.parent.parent.parent
  12. sys.path.insert(0, str(_PROJECT_ROOT / "backend"))
  13. from app.config import settings
  14. # 将SQLite URL转换为异步版本(aiosqlite)
  15. def _build_async_url(url: str) -> str:
  16. """将 sqlite:/// 格式转为 sqlite+aiosqlite:/// 格式"""
  17. if url.startswith("sqlite:///"):
  18. return url.replace("sqlite:///", "sqlite+aiosqlite:///")
  19. return url
  20. # 确保数据目录存在
  21. settings.DATA_DIR.mkdir(parents=True, exist_ok=True)
  22. # 创建异步引擎
  23. engine = create_async_engine(
  24. _build_async_url(settings.DATABASE_URL),
  25. echo=False, # 开发时可设为True查看SQL日志
  26. )
  27. # 创建异步会话工厂
  28. async_session_factory = async_sessionmaker(
  29. engine,
  30. class_=AsyncSession,
  31. expire_on_commit=False,
  32. )
  33. # SQLAlchemy声明式基类
  34. class Base(DeclarativeBase):
  35. pass
  36. async def init_db():
  37. """初始化数据库,创建所有表"""
  38. async with engine.begin() as conn:
  39. await conn.run_sync(Base.metadata.create_all)
  40. async def get_db_session() -> AsyncSession:
  41. """获取数据库会话(FastAPI依赖注入用)"""
  42. async with async_session_factory() as session:
  43. try:
  44. yield session
  45. finally:
  46. await session.close()
  47. async def close_db():
  48. """关闭数据库连接"""
  49. await engine.dispose()