fix db session

This commit is contained in:
jxxghp 2023-10-18 18:30:09 +08:00
parent 68e16d18fe
commit 2e4536edb6
10 changed files with 48 additions and 48 deletions

View File

@ -1,5 +1,3 @@
import threading
from sqlalchemy import create_engine, QueuePool from sqlalchemy import create_engine, QueuePool
from sqlalchemy.orm import sessionmaker, Session, scoped_session from sqlalchemy.orm import sessionmaker, Session, scoped_session
@ -16,14 +14,11 @@ Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db",
max_overflow=10, max_overflow=10,
connect_args={"timeout": 60}) connect_args={"timeout": 60})
# 会话工厂 # 会话工厂
SessionFactory = sessionmaker(autocommit=False, autoflush=False, bind=Engine) SessionFactory = sessionmaker(bind=Engine)
# 多线程全局使用的数据库会话 # 多线程全局使用的数据库会话
ScopedSession = scoped_session(SessionFactory) ScopedSession = scoped_session(SessionFactory)
# 数据库锁
DBLock = threading.Lock()
def get_db(): def get_db():
""" """
@ -39,18 +34,6 @@ def get_db():
db.close() db.close()
def db_lock(func):
"""
使用DBLock加锁防止多线程同时操作数据库
装饰器
"""
def wrapper(*args, **kwargs):
with DBLock:
return func(*args, **kwargs)
return wrapper
class DbOper: class DbOper:
""" """
数据库操作基类 数据库操作基类

View File

@ -1,45 +1,65 @@
import threading
from typing import Any, Self, List from typing import Any, Self, List
from sqlalchemy.orm import as_declarative, declared_attr, Session from sqlalchemy.orm import as_declarative, declared_attr, Session
from app.db import ScopedSession
# 数据库锁
DBLock = threading.Lock()
def db_persist(func):
"""
数据库操作装饰器获取第一个输入参数db执行数据库操作后提交
"""
def wrapper(*args, **kwargs):
with DBLock:
db: Session = kwargs.get("db") or args[1]
try:
if db:
db.close()
db = ScopedSession()
result = func(*args, **kwargs)
db.commit()
except Exception as err:
db.rollback()
raise err
return result
return wrapper
@as_declarative() @as_declarative()
class Base: class Base:
id: Any id: Any
__name__: str __name__: str
@staticmethod @db_persist
def commit(db: Session):
try:
db.commit()
except Exception as err:
db.rollback()
raise err
def create(self, db: Session) -> Self: def create(self, db: Session) -> Self:
db.add(self) db.add(self)
self.commit(db)
return self return self
@classmethod @classmethod
def get(cls, db: Session, rid: int) -> Self: def get(cls, db: Session, rid: int) -> Self:
return db.query(cls).filter(cls.id == rid).first() return db.query(cls).filter(cls.id == rid).first()
@db_persist
def update(self, db: Session, payload: dict): def update(self, db: Session, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None} payload = {k: v for k, v in payload.items() if v is not None}
for key, value in payload.items(): for key, value in payload.items():
setattr(self, key, value) setattr(self, key, value)
Base.commit(db)
@classmethod @classmethod
@db_persist
def delete(cls, db: Session, rid): def delete(cls, db: Session, rid):
db.query(cls).filter(cls.id == rid).delete() db.query(cls).filter(cls.id == rid).delete()
Base.commit(db)
@classmethod @classmethod
@db_persist
def truncate(cls, db: Session): def truncate(cls, db: Session):
db.query(cls).delete() db.query(cls).delete()
Base.commit(db)
@classmethod @classmethod
def list(cls, db: Session) -> List[Self]: def list(cls, db: Session) -> List[Self]:

View File

@ -1,7 +1,7 @@
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base from app.db.models import Base, db_persist
class DownloadHistory(Base): class DownloadHistory(Base):
@ -148,6 +148,7 @@ class DownloadFiles(Base):
return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all() return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all()
@staticmethod @staticmethod
@db_persist
def delete_by_fullpath(db: Session, fullpath: str): def delete_by_fullpath(db: Session, fullpath: str):
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath, db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath,
DownloadFiles.state == 1).update( DownloadFiles.state == 1).update(
@ -155,4 +156,3 @@ class DownloadFiles(Base):
"state": 0 "state": 0
} }
) )
Base.commit(db)

View File

@ -3,7 +3,7 @@ from datetime import datetime
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base from app.db.models import Base, db_persist
class MediaServerItem(Base): class MediaServerItem(Base):
@ -45,9 +45,9 @@ class MediaServerItem(Base):
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first() return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first()
@staticmethod @staticmethod
@db_persist
def empty(db: Session, server: str): def empty(db: Session, server: str):
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() db.query(MediaServerItem).filter(MediaServerItem.server == server).delete()
Base.commit(db)
@staticmethod @staticmethod
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str):

View File

@ -1,7 +1,7 @@
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base from app.db.models import Base, db_persist
class PluginData(Base): class PluginData(Base):
@ -22,9 +22,9 @@ class PluginData(Base):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first() return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first()
@staticmethod @staticmethod
@db_persist
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): def del_plugin_data_by_key(db: Session, plugin_id: str, key: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete()
Base.commit(db)
@staticmethod @staticmethod
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): def get_plugin_data_by_plugin_id(db: Session, plugin_id: str):

View File

@ -3,7 +3,7 @@ from datetime import datetime
from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy import Boolean, Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base from app.db.models import Base, db_persist
class Site(Base): class Site(Base):
@ -59,6 +59,6 @@ class Site(Base):
return db.query(Site).order_by(Site.pri).all() return db.query(Site).order_by(Site.pri).all()
@staticmethod @staticmethod
@db_persist
def reset(db: Session): def reset(db: Session):
db.query(Site).delete() db.query(Site).delete()
Base.commit(db)

View File

@ -3,7 +3,7 @@ import time
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func from sqlalchemy import Column, Integer, String, Sequence, Boolean, func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base from app.db.models import Base, db_persist
class TransferHistory(Base): class TransferHistory(Base):
@ -154,10 +154,10 @@ class TransferHistory(Base):
TransferHistory.type == mtype).first() TransferHistory.type == mtype).first()
@staticmethod @staticmethod
@db_persist
def update_download_hash(db: Session, historyid: int = None, download_hash: str = None): def update_download_hash(db: Session, historyid: int = None, download_hash: str = None):
db.query(TransferHistory).filter(TransferHistory.id == historyid).update( db.query(TransferHistory).filter(TransferHistory.id == historyid).update(
{ {
"download_hash": download_hash "download_hash": download_hash
} }
) )
Base.commit(db)

View File

@ -2,7 +2,7 @@ import time
from typing import Tuple, List from typing import Tuple, List
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.db import DbOper, db_lock from app.db import DbOper
from app.db.models.subscribe import Subscribe from app.db.models.subscribe import Subscribe
@ -11,7 +11,6 @@ class SubscribeOper(DbOper):
订阅管理 订阅管理
""" """
@db_lock
def add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]: def add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]:
""" """
新增订阅 新增订阅
@ -58,14 +57,12 @@ class SubscribeOper(DbOper):
return Subscribe.get_by_state(self._db, state) return Subscribe.get_by_state(self._db, state)
return Subscribe.list(self._db) return Subscribe.list(self._db)
@db_lock
def delete(self, sid: int): def delete(self, sid: int):
""" """
删除订阅 删除订阅
""" """
Subscribe.delete(self._db, rid=sid) Subscribe.delete(self._db, rid=sid)
@db_lock
def update(self, sid: int, payload: dict) -> Subscribe: def update(self, sid: int, payload: dict) -> Subscribe:
""" """
更新订阅 更新订阅

View File

@ -250,7 +250,7 @@ class ChineseSubFinder(_PluginBase):
logger.warn("ChineseSubFinder下载字幕出错%s" % message) logger.warn("ChineseSubFinder下载字幕出错%s" % message)
else: else:
logger.info("ChineseSubFinder任务添加成功%s" % job_id) logger.info("ChineseSubFinder任务添加成功%s" % job_id)
else: elif res.status_code != 200:
logger.warn(f"ChineseSubFinder调用出错{res.status_code} - {res.reason}") logger.warn(f"ChineseSubFinder调用出错{res.status_code} - {res.reason}")
except Exception as e: except Exception as e:
logger.error("连接ChineseSubFinder出错" + str(e)) logger.error("连接ChineseSubFinder出错" + str(e))

View File

@ -577,9 +577,9 @@ class DirMonitor(_PluginBase):
""" """
从表中获取download_hash避免连接下载器 从表中获取download_hash避免连接下载器
""" """
downloadHis = self.downloadhis.get_file_by_fullpath(src) download_file = self.downloadhis.get_file_by_fullpath(src)
if downloadHis: if download_file:
return downloadHis.download_hash return download_file.download_hash
return None return None
def get_state(self) -> bool: def get_state(self) -> bool: