diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index 9411fa0c..fc7b116c 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -1,6 +1,6 @@ from typing import Any -from sqlalchemy.orm import as_declarative, declared_attr +from sqlalchemy.orm import as_declarative, declared_attr, Session @as_declarative() @@ -8,33 +8,41 @@ class Base: id: Any __name__: str - def create(self, db): + @staticmethod + def commit(db: Session): + try: + db.commit() + except Exception as err: + db.rollback() + raise err + + def create(self, db: Session): db.add(self) - db.commit() + self.commit(db) return self @classmethod - def get(cls, db, rid: int): + def get(cls, db: Session, rid: int): return db.query(cls).filter(cls.id == rid).first() - def update(self, db, payload: dict): + def update(self, db: Session, payload: dict): payload = {k: v for k, v in payload.items() if v is not None} for key, value in payload.items(): setattr(self, key, value) - db.commit() + Base.commit(db) @classmethod - def delete(cls, db, rid): + def delete(cls, db: Session, rid): db.query(cls).filter(cls.id == rid).delete() - db.commit() + Base.commit(db) @classmethod - def truncate(cls, db): + def truncate(cls, db: Session): db.query(cls).delete() - db.commit() + Base.commit(db) @classmethod - def list(cls, db): + def list(cls, db: Session): return db.query(cls).all() def to_dict(self): diff --git a/app/db/models/downloadhistory.py b/app/db/models/downloadhistory.py index 0999c86b..dc3dbd42 100644 --- a/app/db/models/downloadhistory.py +++ b/app/db/models/downloadhistory.py @@ -136,4 +136,4 @@ class DownloadFiles(Base): "state": 0 } ) - db.commit() + Base.commit(db) diff --git a/app/db/models/mediaserver.py b/app/db/models/mediaserver.py index ca77ddc7..1714eda7 100644 --- a/app/db/models/mediaserver.py +++ b/app/db/models/mediaserver.py @@ -47,7 +47,7 @@ class MediaServerItem(Base): @staticmethod def empty(db: Session, server: str): db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() - db.commit() + Base.commit(db) @staticmethod def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): diff --git a/app/db/models/plugin.py b/app/db/models/plugin.py index e4af5e58..e060512a 100644 --- a/app/db/models/plugin.py +++ b/app/db/models/plugin.py @@ -24,7 +24,7 @@ class PluginData(Base): @staticmethod def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() - db.commit() + Base.commit(db) @staticmethod def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): diff --git a/app/db/models/site.py b/app/db/models/site.py index b9defde5..ccb735fb 100644 --- a/app/db/models/site.py +++ b/app/db/models/site.py @@ -61,4 +61,4 @@ class Site(Base): @staticmethod def reset(db: Session): db.query(Site).delete() - db.commit() + Base.commit(db) diff --git a/app/db/models/transferhistory.py b/app/db/models/transferhistory.py index b004877b..dd8eb392 100644 --- a/app/db/models/transferhistory.py +++ b/app/db/models/transferhistory.py @@ -122,4 +122,4 @@ class TransferHistory(Base): "download_hash": download_hash } ) - db.commit() + Base.commit(db)