fix db session

This commit is contained in:
jxxghp
2023-10-19 16:58:38 +08:00
parent 21908bdc6f
commit a911bab7b0
11 changed files with 258 additions and 114 deletions

View File

@ -1,13 +1,10 @@
import threading
from typing import Tuple
from sqlalchemy import create_engine, QueuePool
from sqlalchemy.orm import sessionmaker, Session, scoped_session
from app.core.config import settings
# 数据库锁
DBLock = threading.Lock()
# 数据库引擎
Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db",
pool_pre_ping=True,
@ -27,7 +24,7 @@ ScopedSession = scoped_session(SessionFactory)
def get_db():
"""
获取数据库会话
获取数据库会话用于WEB请求
:return: Session
"""
db = None
@ -39,6 +36,105 @@ def get_db():
db.close()
def get_args_db(args: tuple, kwargs: dict):
"""
从参数中获取数据库Session对象
"""
db = None
if args:
for arg in args:
if isinstance(arg, Session):
db = arg
break
if kwargs:
for key, value in kwargs.items():
if isinstance(value, Session):
db = value
break
return db
def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
"""
更新参数中的数据库Session对象关键字传参时更新db的值否则更新第1或第2个参数
"""
if kwargs:
kwargs['db'] = db
elif args:
if args[0] is None:
args = (db, *args[1:])
else:
args = (args[0], db, *args[2:])
return args, kwargs
def db_update(func):
"""
数据库更新类操作装饰器第一个参数必须是数据库会话或存在db参数
"""
def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取数据库会话
db = get_args_db(args, kwargs)
if not db:
# 如果没有获取到数据库会话,创建一个
db = ScopedSession()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db)
try:
# 执行函数
result = func(*args, **kwargs)
# 提交事务
db.commit()
except Exception as err:
# 回滚事务
db.rollback()
raise err
finally:
# 关闭数据库会话
if _close_db:
db.close()
return result
return wrapper
def db_query(func):
"""
数据库查询操作装饰器第一个参数必须是数据库会话或存在db参数
注意db.query列表数据时需要转换为list返回
"""
def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取数据库会话
db = get_args_db(args, kwargs)
if not db:
# 如果没有获取到数据库会话,创建一个
db = ScopedSession()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db)
try:
# 执行函数
result = func(*args, **kwargs)
except Exception as err:
raise err
finally:
# 关闭数据库会话
if _close_db:
db.close()
return result
return wrapper
class DbOper:
"""
数据库操作基类
@ -46,8 +142,4 @@ class DbOper:
_db: Session = None
def __init__(self, db: Session = None):
if db:
self._db = db
else:
with DBLock:
self._db = ScopedSession()
self._db = db