146 lines
3.8 KiB
Python
146 lines
3.8 KiB
Python
from typing import Tuple
|
||
|
||
from sqlalchemy import create_engine, QueuePool
|
||
from sqlalchemy.orm import sessionmaker, Session, scoped_session
|
||
|
||
from app.core.config import settings
|
||
|
||
# 数据库引擎
|
||
Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db",
|
||
pool_pre_ping=True,
|
||
echo=False,
|
||
poolclass=QueuePool,
|
||
pool_size=1024,
|
||
pool_recycle=3600,
|
||
pool_timeout=180,
|
||
max_overflow=10,
|
||
connect_args={"timeout": 60})
|
||
# 会话工厂
|
||
SessionFactory = sessionmaker(bind=Engine)
|
||
|
||
# 多线程全局使用的数据库会话
|
||
ScopedSession = scoped_session(SessionFactory)
|
||
|
||
|
||
def get_db():
|
||
"""
|
||
获取数据库会话,用于WEB请求
|
||
:return: Session
|
||
"""
|
||
db = None
|
||
try:
|
||
db = SessionFactory()
|
||
yield db
|
||
finally:
|
||
if db:
|
||
db.close()
|
||
|
||
|
||
def get_args_db(args: tuple, kwargs: dict):
|
||
"""
|
||
从参数中获取数据库Session对象
|
||
"""
|
||
db = None
|
||
if args:
|
||
for arg in args:
|
||
if isinstance(arg, Session):
|
||
db = arg
|
||
break
|
||
if kwargs:
|
||
for key, value in kwargs.items():
|
||
if isinstance(value, Session):
|
||
db = value
|
||
break
|
||
return db
|
||
|
||
|
||
def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
|
||
"""
|
||
更新参数中的数据库Session对象,关键字传参时更新db的值,否则更新第1或第2个参数
|
||
"""
|
||
if kwargs:
|
||
kwargs['db'] = db
|
||
elif args:
|
||
if args[0] is None:
|
||
args = (db, *args[1:])
|
||
else:
|
||
args = (args[0], db, *args[2:])
|
||
return args, kwargs
|
||
|
||
|
||
def db_update(func):
|
||
"""
|
||
数据库更新类操作装饰器,第一个参数必须是数据库会话或存在db参数
|
||
"""
|
||
|
||
def wrapper(*args, **kwargs):
|
||
# 是否关闭数据库会话
|
||
_close_db = False
|
||
# 从参数中获取数据库会话
|
||
db = get_args_db(args, kwargs)
|
||
if not db:
|
||
# 如果没有获取到数据库会话,创建一个
|
||
db = ScopedSession()
|
||
# 标记需要关闭数据库会话
|
||
_close_db = True
|
||
# 更新参数中的数据库会话
|
||
args, kwargs = update_args_db(args, kwargs, db)
|
||
try:
|
||
# 执行函数
|
||
result = func(*args, **kwargs)
|
||
# 提交事务
|
||
db.commit()
|
||
except Exception as err:
|
||
# 回滚事务
|
||
db.rollback()
|
||
raise err
|
||
finally:
|
||
# 关闭数据库会话
|
||
if _close_db:
|
||
db.close()
|
||
return result
|
||
|
||
return wrapper
|
||
|
||
|
||
def db_query(func):
|
||
"""
|
||
数据库查询操作装饰器,第一个参数必须是数据库会话或存在db参数
|
||
注意:db.query列表数据时,需要转换为list返回
|
||
"""
|
||
|
||
def wrapper(*args, **kwargs):
|
||
# 是否关闭数据库会话
|
||
_close_db = False
|
||
# 从参数中获取数据库会话
|
||
db = get_args_db(args, kwargs)
|
||
if not db:
|
||
# 如果没有获取到数据库会话,创建一个
|
||
db = ScopedSession()
|
||
# 标记需要关闭数据库会话
|
||
_close_db = True
|
||
# 更新参数中的数据库会话
|
||
args, kwargs = update_args_db(args, kwargs, db)
|
||
try:
|
||
# 执行函数
|
||
result = func(*args, **kwargs)
|
||
except Exception as err:
|
||
raise err
|
||
finally:
|
||
# 关闭数据库会话
|
||
if _close_db:
|
||
db.close()
|
||
return result
|
||
|
||
return wrapper
|
||
|
||
|
||
class DbOper:
|
||
"""
|
||
数据库操作基类
|
||
"""
|
||
_db: Session = None
|
||
|
||
def __init__(self, db: Session = None):
|
||
self._db = db
|