fix 统一处理db事务回滚

This commit is contained in:
jxxghp
2023-09-05 18:19:02 +08:00
parent 51b959cff8
commit b5fc6cdd1e
6 changed files with 24 additions and 16 deletions

View File

@ -1,6 +1,6 @@
from typing import Any from typing import Any
from sqlalchemy.orm import as_declarative, declared_attr from sqlalchemy.orm import as_declarative, declared_attr, Session
@as_declarative() @as_declarative()
@ -8,33 +8,41 @@ class Base:
id: Any id: Any
__name__: str __name__: str
def create(self, db): @staticmethod
def commit(db: Session):
try:
db.commit()
except Exception as err:
db.rollback()
raise err
def create(self, db: Session):
db.add(self) db.add(self)
db.commit() self.commit(db)
return self return self
@classmethod @classmethod
def get(cls, db, rid: int): def get(cls, db: Session, rid: int):
return db.query(cls).filter(cls.id == rid).first() return db.query(cls).filter(cls.id == rid).first()
def update(self, db, payload: dict): def update(self, db: Session, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None} payload = {k: v for k, v in payload.items() if v is not None}
for key, value in payload.items(): for key, value in payload.items():
setattr(self, key, value) setattr(self, key, value)
db.commit() Base.commit(db)
@classmethod @classmethod
def delete(cls, db, rid): def delete(cls, db: Session, rid):
db.query(cls).filter(cls.id == rid).delete() db.query(cls).filter(cls.id == rid).delete()
db.commit() Base.commit(db)
@classmethod @classmethod
def truncate(cls, db): def truncate(cls, db: Session):
db.query(cls).delete() db.query(cls).delete()
db.commit() Base.commit(db)
@classmethod @classmethod
def list(cls, db): def list(cls, db: Session):
return db.query(cls).all() return db.query(cls).all()
def to_dict(self): def to_dict(self):

View File

@ -136,4 +136,4 @@ class DownloadFiles(Base):
"state": 0 "state": 0
} }
) )
db.commit() Base.commit(db)

View File

@ -47,7 +47,7 @@ class MediaServerItem(Base):
@staticmethod @staticmethod
def empty(db: Session, server: str): def empty(db: Session, server: str):
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() db.query(MediaServerItem).filter(MediaServerItem.server == server).delete()
db.commit() Base.commit(db)
@staticmethod @staticmethod
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str):

View File

@ -24,7 +24,7 @@ class PluginData(Base):
@staticmethod @staticmethod
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): def del_plugin_data_by_key(db: Session, plugin_id: str, key: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete()
db.commit() Base.commit(db)
@staticmethod @staticmethod
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): def get_plugin_data_by_plugin_id(db: Session, plugin_id: str):

View File

@ -61,4 +61,4 @@ class Site(Base):
@staticmethod @staticmethod
def reset(db: Session): def reset(db: Session):
db.query(Site).delete() db.query(Site).delete()
db.commit() Base.commit(db)

View File

@ -122,4 +122,4 @@ class TransferHistory(Base):
"download_hash": download_hash "download_hash": download_hash
} }
) )
db.commit() Base.commit(db)