fix db session
This commit is contained in:
parent
68e16d18fe
commit
2e4536edb6
@ -1,5 +1,3 @@
|
|||||||
import threading
|
|
||||||
|
|
||||||
from sqlalchemy import create_engine, QueuePool
|
from sqlalchemy import create_engine, QueuePool
|
||||||
from sqlalchemy.orm import sessionmaker, Session, scoped_session
|
from sqlalchemy.orm import sessionmaker, Session, scoped_session
|
||||||
|
|
||||||
@ -16,14 +14,11 @@ Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db",
|
|||||||
max_overflow=10,
|
max_overflow=10,
|
||||||
connect_args={"timeout": 60})
|
connect_args={"timeout": 60})
|
||||||
# 会话工厂
|
# 会话工厂
|
||||||
SessionFactory = sessionmaker(autocommit=False, autoflush=False, bind=Engine)
|
SessionFactory = sessionmaker(bind=Engine)
|
||||||
|
|
||||||
# 多线程全局使用的数据库会话
|
# 多线程全局使用的数据库会话
|
||||||
ScopedSession = scoped_session(SessionFactory)
|
ScopedSession = scoped_session(SessionFactory)
|
||||||
|
|
||||||
# 数据库锁
|
|
||||||
DBLock = threading.Lock()
|
|
||||||
|
|
||||||
|
|
||||||
def get_db():
|
def get_db():
|
||||||
"""
|
"""
|
||||||
@ -39,18 +34,6 @@ def get_db():
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
def db_lock(func):
|
|
||||||
"""
|
|
||||||
使用DBLock加锁,防止多线程同时操作数据库
|
|
||||||
装饰器
|
|
||||||
"""
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
with DBLock:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class DbOper:
|
class DbOper:
|
||||||
"""
|
"""
|
||||||
数据库操作基类
|
数据库操作基类
|
||||||
|
@ -1,45 +1,65 @@
|
|||||||
|
import threading
|
||||||
from typing import Any, Self, List
|
from typing import Any, Self, List
|
||||||
|
|
||||||
from sqlalchemy.orm import as_declarative, declared_attr, Session
|
from sqlalchemy.orm import as_declarative, declared_attr, Session
|
||||||
|
|
||||||
|
from app.db import ScopedSession
|
||||||
|
|
||||||
|
# 数据库锁
|
||||||
|
DBLock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def db_persist(func):
|
||||||
|
"""
|
||||||
|
数据库操作装饰器,获取第一个输入参数db,执行数据库操作后提交
|
||||||
|
"""
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
with DBLock:
|
||||||
|
db: Session = kwargs.get("db") or args[1]
|
||||||
|
try:
|
||||||
|
if db:
|
||||||
|
db.close()
|
||||||
|
db = ScopedSession()
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
db.commit()
|
||||||
|
except Exception as err:
|
||||||
|
db.rollback()
|
||||||
|
raise err
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@as_declarative()
|
@as_declarative()
|
||||||
class Base:
|
class Base:
|
||||||
id: Any
|
id: Any
|
||||||
__name__: str
|
__name__: str
|
||||||
|
|
||||||
@staticmethod
|
@db_persist
|
||||||
def commit(db: Session):
|
|
||||||
try:
|
|
||||||
db.commit()
|
|
||||||
except Exception as err:
|
|
||||||
db.rollback()
|
|
||||||
raise err
|
|
||||||
|
|
||||||
def create(self, db: Session) -> Self:
|
def create(self, db: Session) -> Self:
|
||||||
db.add(self)
|
db.add(self)
|
||||||
self.commit(db)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get(cls, db: Session, rid: int) -> Self:
|
def get(cls, db: Session, rid: int) -> Self:
|
||||||
return db.query(cls).filter(cls.id == rid).first()
|
return db.query(cls).filter(cls.id == rid).first()
|
||||||
|
|
||||||
|
@db_persist
|
||||||
def update(self, db: Session, payload: dict):
|
def update(self, db: Session, payload: dict):
|
||||||
payload = {k: v for k, v in payload.items() if v is not None}
|
payload = {k: v for k, v in payload.items() if v is not None}
|
||||||
for key, value in payload.items():
|
for key, value in payload.items():
|
||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
Base.commit(db)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@db_persist
|
||||||
def delete(cls, db: Session, rid):
|
def delete(cls, db: Session, rid):
|
||||||
db.query(cls).filter(cls.id == rid).delete()
|
db.query(cls).filter(cls.id == rid).delete()
|
||||||
Base.commit(db)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@db_persist
|
||||||
def truncate(cls, db: Session):
|
def truncate(cls, db: Session):
|
||||||
db.query(cls).delete()
|
db.query(cls).delete()
|
||||||
Base.commit(db)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def list(cls, db: Session) -> List[Self]:
|
def list(cls, db: Session) -> List[Self]:
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from sqlalchemy import Column, Integer, String, Sequence
|
from sqlalchemy import Column, Integer, String, Sequence
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.db.models import Base
|
from app.db.models import Base, db_persist
|
||||||
|
|
||||||
|
|
||||||
class DownloadHistory(Base):
|
class DownloadHistory(Base):
|
||||||
@ -148,6 +148,7 @@ class DownloadFiles(Base):
|
|||||||
return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all()
|
return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@db_persist
|
||||||
def delete_by_fullpath(db: Session, fullpath: str):
|
def delete_by_fullpath(db: Session, fullpath: str):
|
||||||
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath,
|
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath,
|
||||||
DownloadFiles.state == 1).update(
|
DownloadFiles.state == 1).update(
|
||||||
@ -155,4 +156,3 @@ class DownloadFiles(Base):
|
|||||||
"state": 0
|
"state": 0
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
Base.commit(db)
|
|
||||||
|
@ -3,7 +3,7 @@ from datetime import datetime
|
|||||||
from sqlalchemy import Column, Integer, String, Sequence
|
from sqlalchemy import Column, Integer, String, Sequence
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.db.models import Base
|
from app.db.models import Base, db_persist
|
||||||
|
|
||||||
|
|
||||||
class MediaServerItem(Base):
|
class MediaServerItem(Base):
|
||||||
@ -45,9 +45,9 @@ class MediaServerItem(Base):
|
|||||||
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first()
|
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@db_persist
|
||||||
def empty(db: Session, server: str):
|
def empty(db: Session, server: str):
|
||||||
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete()
|
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete()
|
||||||
Base.commit(db)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str):
|
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str):
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from sqlalchemy import Column, Integer, String, Sequence
|
from sqlalchemy import Column, Integer, String, Sequence
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.db.models import Base
|
from app.db.models import Base, db_persist
|
||||||
|
|
||||||
|
|
||||||
class PluginData(Base):
|
class PluginData(Base):
|
||||||
@ -22,9 +22,9 @@ class PluginData(Base):
|
|||||||
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first()
|
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@db_persist
|
||||||
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str):
|
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str):
|
||||||
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete()
|
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete()
|
||||||
Base.commit(db)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str):
|
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str):
|
||||||
|
@ -3,7 +3,7 @@ from datetime import datetime
|
|||||||
from sqlalchemy import Boolean, Column, Integer, String, Sequence
|
from sqlalchemy import Boolean, Column, Integer, String, Sequence
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.db.models import Base
|
from app.db.models import Base, db_persist
|
||||||
|
|
||||||
|
|
||||||
class Site(Base):
|
class Site(Base):
|
||||||
@ -59,6 +59,6 @@ class Site(Base):
|
|||||||
return db.query(Site).order_by(Site.pri).all()
|
return db.query(Site).order_by(Site.pri).all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@db_persist
|
||||||
def reset(db: Session):
|
def reset(db: Session):
|
||||||
db.query(Site).delete()
|
db.query(Site).delete()
|
||||||
Base.commit(db)
|
|
||||||
|
@ -3,7 +3,7 @@ import time
|
|||||||
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func
|
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.db.models import Base
|
from app.db.models import Base, db_persist
|
||||||
|
|
||||||
|
|
||||||
class TransferHistory(Base):
|
class TransferHistory(Base):
|
||||||
@ -154,10 +154,10 @@ class TransferHistory(Base):
|
|||||||
TransferHistory.type == mtype).first()
|
TransferHistory.type == mtype).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@db_persist
|
||||||
def update_download_hash(db: Session, historyid: int = None, download_hash: str = None):
|
def update_download_hash(db: Session, historyid: int = None, download_hash: str = None):
|
||||||
db.query(TransferHistory).filter(TransferHistory.id == historyid).update(
|
db.query(TransferHistory).filter(TransferHistory.id == historyid).update(
|
||||||
{
|
{
|
||||||
"download_hash": download_hash
|
"download_hash": download_hash
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
Base.commit(db)
|
|
||||||
|
@ -2,7 +2,7 @@ import time
|
|||||||
from typing import Tuple, List
|
from typing import Tuple, List
|
||||||
|
|
||||||
from app.core.context import MediaInfo
|
from app.core.context import MediaInfo
|
||||||
from app.db import DbOper, db_lock
|
from app.db import DbOper
|
||||||
from app.db.models.subscribe import Subscribe
|
from app.db.models.subscribe import Subscribe
|
||||||
|
|
||||||
|
|
||||||
@ -11,7 +11,6 @@ class SubscribeOper(DbOper):
|
|||||||
订阅管理
|
订阅管理
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@db_lock
|
|
||||||
def add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]:
|
def add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]:
|
||||||
"""
|
"""
|
||||||
新增订阅
|
新增订阅
|
||||||
@ -58,14 +57,12 @@ class SubscribeOper(DbOper):
|
|||||||
return Subscribe.get_by_state(self._db, state)
|
return Subscribe.get_by_state(self._db, state)
|
||||||
return Subscribe.list(self._db)
|
return Subscribe.list(self._db)
|
||||||
|
|
||||||
@db_lock
|
|
||||||
def delete(self, sid: int):
|
def delete(self, sid: int):
|
||||||
"""
|
"""
|
||||||
删除订阅
|
删除订阅
|
||||||
"""
|
"""
|
||||||
Subscribe.delete(self._db, rid=sid)
|
Subscribe.delete(self._db, rid=sid)
|
||||||
|
|
||||||
@db_lock
|
|
||||||
def update(self, sid: int, payload: dict) -> Subscribe:
|
def update(self, sid: int, payload: dict) -> Subscribe:
|
||||||
"""
|
"""
|
||||||
更新订阅
|
更新订阅
|
||||||
|
@ -250,7 +250,7 @@ class ChineseSubFinder(_PluginBase):
|
|||||||
logger.warn("ChineseSubFinder下载字幕出错:%s" % message)
|
logger.warn("ChineseSubFinder下载字幕出错:%s" % message)
|
||||||
else:
|
else:
|
||||||
logger.info("ChineseSubFinder任务添加成功:%s" % job_id)
|
logger.info("ChineseSubFinder任务添加成功:%s" % job_id)
|
||||||
else:
|
elif res.status_code != 200:
|
||||||
logger.warn(f"ChineseSubFinder调用出错:{res.status_code} - {res.reason}")
|
logger.warn(f"ChineseSubFinder调用出错:{res.status_code} - {res.reason}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("连接ChineseSubFinder出错:" + str(e))
|
logger.error("连接ChineseSubFinder出错:" + str(e))
|
||||||
|
@ -577,9 +577,9 @@ class DirMonitor(_PluginBase):
|
|||||||
"""
|
"""
|
||||||
从表中获取download_hash,避免连接下载器
|
从表中获取download_hash,避免连接下载器
|
||||||
"""
|
"""
|
||||||
downloadHis = self.downloadhis.get_file_by_fullpath(src)
|
download_file = self.downloadhis.get_file_by_fullpath(src)
|
||||||
if downloadHis:
|
if download_file:
|
||||||
return downloadHis.download_hash
|
return download_file.download_hash
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_state(self) -> bool:
|
def get_state(self) -> bool:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user