from typing import Tuple, Optional, Generator from sqlalchemy import create_engine, QueuePool from sqlalchemy.orm import sessionmaker, Session, scoped_session from app.core.config import settings # 数据库引擎 Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db", pool_pre_ping=True, echo=False, poolclass=QueuePool, pool_size=1024, pool_recycle=3600, pool_timeout=180, max_overflow=10, connect_args={"timeout": 60}) # 会话工厂 SessionFactory = sessionmaker(bind=Engine) # 多线程全局使用的数据库会话 ScopedSession = scoped_session(SessionFactory) def get_db() -> Generator: """ 获取数据库会话,用于WEB请求 :return: Session """ db = None try: db = SessionFactory() yield db finally: if db: db.close() def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]: """ 从参数中获取数据库Session对象 """ db = None if args: for arg in args: if isinstance(arg, Session): db = arg break if kwargs: for key, value in kwargs.items(): if isinstance(value, Session): db = value break return db def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]: """ 更新参数中的数据库Session对象,关键字传参时更新db的值,否则更新第1或第2个参数 """ if kwargs and 'db' in kwargs: kwargs['db'] = db elif args: if args[0] is None: args = (db, *args[1:]) else: args = (args[0], db, *args[2:]) return args, kwargs def db_update(func): """ 数据库更新类操作装饰器,第一个参数必须是数据库会话或存在db参数 """ def wrapper(*args, **kwargs): # 是否关闭数据库会话 _close_db = False # 从参数中获取数据库会话 db = get_args_db(args, kwargs) if not db: # 如果没有获取到数据库会话,创建一个 db = ScopedSession() # 标记需要关闭数据库会话 _close_db = True # 更新参数中的数据库会话 args, kwargs = update_args_db(args, kwargs, db) try: # 执行函数 result = func(*args, **kwargs) # 提交事务 db.commit() except Exception as err: # 回滚事务 db.rollback() raise err finally: # 关闭数据库会话 if _close_db: db.close() return result return wrapper def db_query(func): """ 数据库查询操作装饰器,第一个参数必须是数据库会话或存在db参数 注意:db.query列表数据时,需要转换为list返回 """ def wrapper(*args, **kwargs): # 是否关闭数据库会话 _close_db = False # 从参数中获取数据库会话 db = get_args_db(args, kwargs) if not db: # 如果没有获取到数据库会话,创建一个 db = ScopedSession() # 标记需要关闭数据库会话 _close_db = True # 更新参数中的数据库会话 args, kwargs = update_args_db(args, kwargs, db) try: # 执行函数 result = func(*args, **kwargs) except Exception as err: raise err finally: # 关闭数据库会话 if _close_db: db.close() return result return wrapper class DbOper: """ 数据库操作基类 """ _db: Session = None def __init__(self, db: Session = None): self._db = db