fix db session
This commit is contained in:
@ -1,45 +1,65 @@
|
||||
import threading
|
||||
from typing import Any, Self, List
|
||||
|
||||
from sqlalchemy.orm import as_declarative, declared_attr, Session
|
||||
|
||||
from app.db import ScopedSession
|
||||
|
||||
# 数据库锁
|
||||
DBLock = threading.Lock()
|
||||
|
||||
|
||||
def db_persist(func):
|
||||
"""
|
||||
数据库操作装饰器,获取第一个输入参数db,执行数据库操作后提交
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
with DBLock:
|
||||
db: Session = kwargs.get("db") or args[1]
|
||||
try:
|
||||
if db:
|
||||
db.close()
|
||||
db = ScopedSession()
|
||||
result = func(*args, **kwargs)
|
||||
db.commit()
|
||||
except Exception as err:
|
||||
db.rollback()
|
||||
raise err
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@as_declarative()
|
||||
class Base:
|
||||
id: Any
|
||||
__name__: str
|
||||
|
||||
@staticmethod
|
||||
def commit(db: Session):
|
||||
try:
|
||||
db.commit()
|
||||
except Exception as err:
|
||||
db.rollback()
|
||||
raise err
|
||||
|
||||
@db_persist
|
||||
def create(self, db: Session) -> Self:
|
||||
db.add(self)
|
||||
self.commit(db)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get(cls, db: Session, rid: int) -> Self:
|
||||
return db.query(cls).filter(cls.id == rid).first()
|
||||
|
||||
@db_persist
|
||||
def update(self, db: Session, payload: dict):
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
for key, value in payload.items():
|
||||
setattr(self, key, value)
|
||||
Base.commit(db)
|
||||
|
||||
@classmethod
|
||||
@db_persist
|
||||
def delete(cls, db: Session, rid):
|
||||
db.query(cls).filter(cls.id == rid).delete()
|
||||
Base.commit(db)
|
||||
|
||||
@classmethod
|
||||
@db_persist
|
||||
def truncate(cls, db: Session):
|
||||
db.query(cls).delete()
|
||||
Base.commit(db)
|
||||
|
||||
@classmethod
|
||||
def list(cls, db: Session) -> List[Self]:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from sqlalchemy import Column, Integer, String, Sequence
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.models import Base
|
||||
from app.db.models import Base, db_persist
|
||||
|
||||
|
||||
class DownloadHistory(Base):
|
||||
@ -148,6 +148,7 @@ class DownloadFiles(Base):
|
||||
return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all()
|
||||
|
||||
@staticmethod
|
||||
@db_persist
|
||||
def delete_by_fullpath(db: Session, fullpath: str):
|
||||
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath,
|
||||
DownloadFiles.state == 1).update(
|
||||
@ -155,4 +156,3 @@ class DownloadFiles(Base):
|
||||
"state": 0
|
||||
}
|
||||
)
|
||||
Base.commit(db)
|
||||
|
@ -3,7 +3,7 @@ from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Sequence
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.models import Base
|
||||
from app.db.models import Base, db_persist
|
||||
|
||||
|
||||
class MediaServerItem(Base):
|
||||
@ -45,9 +45,9 @@ class MediaServerItem(Base):
|
||||
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first()
|
||||
|
||||
@staticmethod
|
||||
@db_persist
|
||||
def empty(db: Session, server: str):
|
||||
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete()
|
||||
Base.commit(db)
|
||||
|
||||
@staticmethod
|
||||
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str):
|
||||
|
@ -1,7 +1,7 @@
|
||||
from sqlalchemy import Column, Integer, String, Sequence
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.models import Base
|
||||
from app.db.models import Base, db_persist
|
||||
|
||||
|
||||
class PluginData(Base):
|
||||
@ -22,9 +22,9 @@ class PluginData(Base):
|
||||
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first()
|
||||
|
||||
@staticmethod
|
||||
@db_persist
|
||||
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str):
|
||||
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete()
|
||||
Base.commit(db)
|
||||
|
||||
@staticmethod
|
||||
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str):
|
||||
|
@ -3,7 +3,7 @@ from datetime import datetime
|
||||
from sqlalchemy import Boolean, Column, Integer, String, Sequence
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.models import Base
|
||||
from app.db.models import Base, db_persist
|
||||
|
||||
|
||||
class Site(Base):
|
||||
@ -59,6 +59,6 @@ class Site(Base):
|
||||
return db.query(Site).order_by(Site.pri).all()
|
||||
|
||||
@staticmethod
|
||||
@db_persist
|
||||
def reset(db: Session):
|
||||
db.query(Site).delete()
|
||||
Base.commit(db)
|
||||
|
@ -3,7 +3,7 @@ import time
|
||||
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.models import Base
|
||||
from app.db.models import Base, db_persist
|
||||
|
||||
|
||||
class TransferHistory(Base):
|
||||
@ -154,10 +154,10 @@ class TransferHistory(Base):
|
||||
TransferHistory.type == mtype).first()
|
||||
|
||||
@staticmethod
|
||||
@db_persist
|
||||
def update_download_hash(db: Session, historyid: int = None, download_hash: str = None):
|
||||
db.query(TransferHistory).filter(TransferHistory.id == historyid).update(
|
||||
{
|
||||
"download_hash": download_hash
|
||||
}
|
||||
)
|
||||
Base.commit(db)
|
||||
|
Reference in New Issue
Block a user