user_service.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. """
  2. 用户服务
  3. """
  4. from typing import Optional, List
  5. from sqlalchemy.orm import Session
  6. from passlib.context import CryptContext
  7. from ..core.database import get_db
  8. from ..models.user import UserDB, User, UserCreate, UserUpdate
  9. from ..core.exceptions import UserNotFoundError, UserAlreadyExistsError
# Shared passlib hashing context configured with the bcrypt scheme; used by
# UserService to hash new passwords and verify stored hashes.
# deprecated="auto" asks passlib to treat non-default schemes as deprecated
# (see passlib CryptContext docs).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  11. class UserService:
  12. """用户服务类"""
  13. def __init__(self, db: Session):
  14. self.db = db
  15. @staticmethod
  16. def verify_password(plain_password: str, hashed_password: str) -> bool:
  17. """验证密码"""
  18. return pwd_context.verify(plain_password, hashed_password)
  19. @staticmethod
  20. def get_password_hash(password: str) -> str:
  21. """获取密码哈希"""
  22. return pwd_context.hash(password)
  23. def get_user_by_id(self, user_id: int) -> Optional[User]:
  24. """根据ID获取用户"""
  25. user_db = self.db.query(UserDB).filter(UserDB.id == user_id).first()
  26. if not user_db:
  27. raise UserNotFoundError(f"User with id {user_id} not found")
  28. return User.from_orm(user_db)
  29. def get_user_by_email(self, email: str) -> Optional[User]:
  30. """根据邮箱获取用户"""
  31. user_db = self.db.query(UserDB).filter(UserDB.email == email).first()
  32. if not user_db:
  33. raise UserNotFoundError(f"User with email {email} not found")
  34. return User.from_orm(user_db)
  35. def get_user_by_username(self, username: str) -> Optional[User]:
  36. """根据用户名获取用户"""
  37. user_db = self.db.query(UserDB).filter(UserDB.username == username).first()
  38. if not user_db:
  39. raise UserNotFoundError(f"User with username {username} not found")
  40. return User.from_orm(user_db)
  41. def create_user(self, user_create: UserCreate) -> User:
  42. """创建用户"""
  43. # 检查邮箱是否已存在
  44. if self.db.query(UserDB).filter(UserDB.email == user_create.email).first():
  45. raise UserAlreadyExistsError(f"Email {user_create.email} already registered")
  46. # 检查用户名是否已存在
  47. if self.db.query(UserDB).filter(UserDB.username == user_create.username).first():
  48. raise UserAlreadyExistsError(f"Username {user_create.username} already taken")
  49. # 创建新用户
  50. hashed_password = self.get_password_hash(user_create.password)
  51. user_db = UserDB(
  52. username=user_create.username,
  53. email=user_create.email,
  54. hashed_password=hashed_password,
  55. full_name=user_create.full_name,
  56. institution=user_create.institution,
  57. research_field=user_create.research_field
  58. )
  59. self.db.add(user_db)
  60. self.db.commit()
  61. self.db.refresh(user_db)
  62. return User.from_orm(user_db)
  63. def update_user(self, user_id: int, user_update: UserUpdate) -> User:
  64. """更新用户信息"""
  65. user_db = self.db.query(UserDB).filter(UserDB.id == user_id).first()
  66. if not user_db:
  67. raise UserNotFoundError(f"User with id {user_id} not found")
  68. # 更新字段
  69. update_data = user_update.dict(exclude_unset=True)
  70. for field, value in update_data.items():
  71. setattr(user_db, field, value)
  72. self.db.commit()
  73. self.db.refresh(user_db)
  74. return User.from_orm(user_db)
  75. def delete_user(self, user_id: int) -> bool:
  76. """删除用户"""
  77. user_db = self.db.query(UserDB).filter(UserDB.id == user_id).first()
  78. if not user_db:
  79. raise UserNotFoundError(f"User with id {user_id} not found")
  80. self.db.delete(user_db)
  81. self.db.commit()
  82. return True
  83. def authenticate_user(self, email: str, password: str) -> Optional[User]:
  84. """验证用户登录"""
  85. try:
  86. user = self.get_user_by_email(email)
  87. if self.verify_password(password, user.hashed_password):
  88. return user
  89. except UserNotFoundError:
  90. pass
  91. return None
  92. def change_password(self, user_id: int, current_password: str, new_password: str) -> bool:
  93. """修改密码"""
  94. user_db = self.db.query(UserDB).filter(UserDB.id == user_id).first()
  95. if not user_db:
  96. raise UserNotFoundError(f"User with id {user_id} not found")
  97. if not self.verify_password(current_password, user_db.hashed_password):
  98. return False
  99. user_db.hashed_password = self.get_password_hash(new_password)
  100. self.db.commit()
  101. return True