From b5fc6cdd1e78c8bca879db59f287ddf666135cec Mon Sep 17 00:00:00 2001 From: jxxghp Date: Tue, 5 Sep 2023 18:19:02 +0800 Subject: [PATCH] =?UTF-8?q?fix=20=E7=BB=9F=E4=B8=80=E5=A4=84=E7=90=86db?= =?UTF-8?q?=E4=BA=8B=E5=8A=A1=E5=9B=9E=E6=BB=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/db/models/__init__.py | 30 +++++++++++++++++++----------- app/db/models/downloadhistory.py | 2 +- app/db/models/mediaserver.py | 2 +- app/db/models/plugin.py | 2 +- app/db/models/site.py | 2 +- app/db/models/transferhistory.py | 2 +- 6 files changed, 24 insertions(+), 16 deletions(-) diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index 9411fa0c..fc7b116c 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -1,6 +1,6 @@ from typing import Any -from sqlalchemy.orm import as_declarative, declared_attr +from sqlalchemy.orm import as_declarative, declared_attr, Session @as_declarative() @@ -8,33 +8,41 @@ class Base: id: Any __name__: str - def create(self, db): + @staticmethod + def commit(db: Session): + try: + db.commit() + except Exception as err: + db.rollback() + raise err + + def create(self, db: Session): db.add(self) - db.commit() + self.commit(db) return self @classmethod - def get(cls, db, rid: int): + def get(cls, db: Session, rid: int): return db.query(cls).filter(cls.id == rid).first() - def update(self, db, payload: dict): + def update(self, db: Session, payload: dict): payload = {k: v for k, v in payload.items() if v is not None} for key, value in payload.items(): setattr(self, key, value) - db.commit() + Base.commit(db) @classmethod - def delete(cls, db, rid): + def delete(cls, db: Session, rid): db.query(cls).filter(cls.id == rid).delete() - db.commit() + Base.commit(db) @classmethod - def truncate(cls, db): + def truncate(cls, db: Session): db.query(cls).delete() - db.commit() + Base.commit(db) @classmethod - def list(cls, db): + def list(cls, db: Session): return db.query(cls).all() def to_dict(self): diff --git a/app/db/models/downloadhistory.py b/app/db/models/downloadhistory.py index 0999c86b..dc3dbd42 100644 --- a/app/db/models/downloadhistory.py +++ b/app/db/models/downloadhistory.py @@ -136,4 +136,4 @@ class DownloadFiles(Base): "state": 0 } ) - db.commit() + Base.commit(db) diff --git a/app/db/models/mediaserver.py b/app/db/models/mediaserver.py index ca77ddc7..1714eda7 100644 --- a/app/db/models/mediaserver.py +++ b/app/db/models/mediaserver.py @@ -47,7 +47,7 @@ class MediaServerItem(Base): @staticmethod def empty(db: Session, server: str): db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() - db.commit() + Base.commit(db) @staticmethod def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): diff --git a/app/db/models/plugin.py b/app/db/models/plugin.py index e4af5e58..e060512a 100644 --- a/app/db/models/plugin.py +++ b/app/db/models/plugin.py @@ -24,7 +24,7 @@ class PluginData(Base): @staticmethod def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() - db.commit() + Base.commit(db) @staticmethod def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): diff --git a/app/db/models/site.py b/app/db/models/site.py index b9defde5..ccb735fb 100644 --- a/app/db/models/site.py +++ b/app/db/models/site.py @@ -61,4 +61,4 @@ class Site(Base): @staticmethod def reset(db: Session): db.query(Site).delete() - db.commit() + Base.commit(db) diff --git a/app/db/models/transferhistory.py b/app/db/models/transferhistory.py index b004877b..dd8eb392 100644 --- a/app/db/models/transferhistory.py +++ b/app/db/models/transferhistory.py @@ -122,4 +122,4 @@ class TransferHistory(Base): "download_hash": download_hash } ) - db.commit() + Base.commit(db)