From 2e4536edb6e396ce4cc21a1e9db117df23679cb8 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Wed, 18 Oct 2023 18:30:09 +0800 Subject: [PATCH] fix db session --- app/db/__init__.py | 19 +--------- app/db/models/__init__.py | 44 +++++++++++++++++------- app/db/models/downloadhistory.py | 4 +-- app/db/models/mediaserver.py | 4 +-- app/db/models/plugin.py | 4 +-- app/db/models/site.py | 4 +-- app/db/models/transferhistory.py | 4 +-- app/db/subscribe_oper.py | 5 +-- app/plugins/chinesesubfinder/__init__.py | 2 +- app/plugins/dirmonitor/__init__.py | 6 ++-- 10 files changed, 48 insertions(+), 48 deletions(-) diff --git a/app/db/__init__.py b/app/db/__init__.py index 1e88811a..a959b0ae 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,5 +1,3 @@ -import threading - from sqlalchemy import create_engine, QueuePool from sqlalchemy.orm import sessionmaker, Session, scoped_session @@ -16,14 +14,11 @@ Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db", max_overflow=10, connect_args={"timeout": 60}) # 会话工厂 -SessionFactory = sessionmaker(autocommit=False, autoflush=False, bind=Engine) +SessionFactory = sessionmaker(bind=Engine) # 多线程全局使用的数据库会话 ScopedSession = scoped_session(SessionFactory) -# 数据库锁 -DBLock = threading.Lock() - def get_db(): """ @@ -39,18 +34,6 @@ def get_db(): db.close() -def db_lock(func): - """ - 使用DBLock加锁,防止多线程同时操作数据库 - 装饰器 - """ - def wrapper(*args, **kwargs): - with DBLock: - return func(*args, **kwargs) - - return wrapper - - class DbOper: """ 数据库操作基类 diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index 086c7f27..5828c6f6 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -1,45 +1,65 @@ +import threading from typing import Any, Self, List from sqlalchemy.orm import as_declarative, declared_attr, Session +from app.db import ScopedSession + +# 数据库锁 +DBLock = threading.Lock() + + +def db_persist(func): + """ + 数据库操作装饰器,获取第一个输入参数db,执行数据库操作后提交 + """ + + def wrapper(*args, **kwargs): + with DBLock: + db: Session = kwargs.get("db") or args[1] + try: + if db: + db.close() + db = ScopedSession() + result = func(*args, **kwargs) + db.commit() + except Exception as err: + db.rollback() + raise err + return result + + return wrapper + @as_declarative() class Base: id: Any __name__: str - @staticmethod - def commit(db: Session): - try: - db.commit() - except Exception as err: - db.rollback() - raise err - + @db_persist def create(self, db: Session) -> Self: db.add(self) - self.commit(db) return self @classmethod def get(cls, db: Session, rid: int) -> Self: return db.query(cls).filter(cls.id == rid).first() + @db_persist def update(self, db: Session, payload: dict): payload = {k: v for k, v in payload.items() if v is not None} for key, value in payload.items(): setattr(self, key, value) - Base.commit(db) @classmethod + @db_persist def delete(cls, db: Session, rid): db.query(cls).filter(cls.id == rid).delete() - Base.commit(db) @classmethod + @db_persist def truncate(cls, db: Session): db.query(cls).delete() - Base.commit(db) @classmethod def list(cls, db: Session) -> List[Self]: diff --git a/app/db/models/downloadhistory.py b/app/db/models/downloadhistory.py index 25e5effe..e249fb25 100644 --- a/app/db/models/downloadhistory.py +++ b/app/db/models/downloadhistory.py @@ -1,7 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base +from app.db.models import Base, db_persist class DownloadHistory(Base): @@ -148,6 +148,7 @@ class DownloadFiles(Base): return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all() @staticmethod + @db_persist def delete_by_fullpath(db: Session, fullpath: str): db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath, DownloadFiles.state == 1).update( @@ -155,4 +156,3 @@ class DownloadFiles(Base): "state": 0 } ) - Base.commit(db) diff --git a/app/db/models/mediaserver.py b/app/db/models/mediaserver.py index 1714eda7..ee0e06e1 100644 --- a/app/db/models/mediaserver.py +++ b/app/db/models/mediaserver.py @@ -3,7 +3,7 @@ from datetime import datetime from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base +from app.db.models import Base, db_persist class MediaServerItem(Base): @@ -45,9 +45,9 @@ class MediaServerItem(Base): return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first() @staticmethod + @db_persist def empty(db: Session, server: str): db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() - Base.commit(db) @staticmethod def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): diff --git a/app/db/models/plugin.py b/app/db/models/plugin.py index e060512a..d1936a84 100644 --- a/app/db/models/plugin.py +++ b/app/db/models/plugin.py @@ -1,7 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base +from app.db.models import Base, db_persist class PluginData(Base): @@ -22,9 +22,9 @@ class PluginData(Base): return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first() @staticmethod + @db_persist def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() - Base.commit(db) @staticmethod def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): diff --git a/app/db/models/site.py b/app/db/models/site.py index ccb735fb..ced51120 100644 --- a/app/db/models/site.py +++ b/app/db/models/site.py @@ -3,7 +3,7 @@ from datetime import datetime from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base +from app.db.models import Base, db_persist class Site(Base): @@ -59,6 +59,6 @@ class Site(Base): return db.query(Site).order_by(Site.pri).all() @staticmethod + @db_persist def reset(db: Session): db.query(Site).delete() - Base.commit(db) diff --git a/app/db/models/transferhistory.py b/app/db/models/transferhistory.py index 0453595d..31c55f5c 100644 --- a/app/db/models/transferhistory.py +++ b/app/db/models/transferhistory.py @@ -3,7 +3,7 @@ import time from sqlalchemy import Column, Integer, String, Sequence, Boolean, func from sqlalchemy.orm import Session -from app.db.models import Base +from app.db.models import Base, db_persist class TransferHistory(Base): @@ -154,10 +154,10 @@ class TransferHistory(Base): TransferHistory.type == mtype).first() @staticmethod + @db_persist def update_download_hash(db: Session, historyid: int = None, download_hash: str = None): db.query(TransferHistory).filter(TransferHistory.id == historyid).update( { "download_hash": download_hash } ) - Base.commit(db) diff --git a/app/db/subscribe_oper.py b/app/db/subscribe_oper.py index aac0f122..574f2791 100644 --- a/app/db/subscribe_oper.py +++ b/app/db/subscribe_oper.py @@ -2,7 +2,7 @@ import time from typing import Tuple, List from app.core.context import MediaInfo -from app.db import DbOper, db_lock +from app.db import DbOper from app.db.models.subscribe import Subscribe @@ -11,7 +11,6 @@ class SubscribeOper(DbOper): 订阅管理 """ - @db_lock def add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]: """ 新增订阅 @@ -58,14 +57,12 @@ class SubscribeOper(DbOper): return Subscribe.get_by_state(self._db, state) return Subscribe.list(self._db) - @db_lock def delete(self, sid: int): """ 删除订阅 """ Subscribe.delete(self._db, rid=sid) - @db_lock def update(self, sid: int, payload: dict) -> Subscribe: """ 更新订阅 diff --git a/app/plugins/chinesesubfinder/__init__.py b/app/plugins/chinesesubfinder/__init__.py index f6bf3c73..7550e440 100644 --- a/app/plugins/chinesesubfinder/__init__.py +++ b/app/plugins/chinesesubfinder/__init__.py @@ -250,7 +250,7 @@ class ChineseSubFinder(_PluginBase): logger.warn("ChineseSubFinder下载字幕出错:%s" % message) else: logger.info("ChineseSubFinder任务添加成功:%s" % job_id) - else: + elif res.status_code != 200: logger.warn(f"ChineseSubFinder调用出错:{res.status_code} - {res.reason}") except Exception as e: logger.error("连接ChineseSubFinder出错:" + str(e)) diff --git a/app/plugins/dirmonitor/__init__.py b/app/plugins/dirmonitor/__init__.py index 3e953252..b0e569c2 100644 --- a/app/plugins/dirmonitor/__init__.py +++ b/app/plugins/dirmonitor/__init__.py @@ -577,9 +577,9 @@ class DirMonitor(_PluginBase): """ 从表中获取download_hash,避免连接下载器 """ - downloadHis = self.downloadhis.get_file_by_fullpath(src) - if downloadHis: - return downloadHis.download_hash + download_file = self.downloadhis.get_file_by_fullpath(src) + if download_file: + return download_file.download_hash return None def get_state(self) -> bool: