From a911bab7b0f9eb5ce94d7ffe6aced803bc206b09 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 19 Oct 2023 16:58:38 +0800 Subject: [PATCH 1/6] fix db session --- app/db/__init__.py | 112 ++++++++++++++++++++++++++++--- app/db/models/__init__.py | 40 +++-------- app/db/models/downloadhistory.py | 74 ++++++++++++-------- app/db/models/mediaserver.py | 8 ++- app/db/models/plugin.py | 14 ++-- app/db/models/site.py | 14 ++-- app/db/models/siteicon.py | 2 + app/db/models/subscribe.py | 19 ++++-- app/db/models/systemconfig.py | 3 + app/db/models/transferhistory.py | 82 +++++++++++++--------- app/db/models/user.py | 4 ++ 11 files changed, 258 insertions(+), 114 deletions(-) diff --git a/app/db/__init__.py b/app/db/__init__.py index dda252a2..3f9a436c 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,13 +1,10 @@ -import threading +from typing import Tuple from sqlalchemy import create_engine, QueuePool from sqlalchemy.orm import sessionmaker, Session, scoped_session from app.core.config import settings -# 数据库锁 -DBLock = threading.Lock() - # 数据库引擎 Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db", pool_pre_ping=True, @@ -27,7 +24,7 @@ ScopedSession = scoped_session(SessionFactory) def get_db(): """ - 获取数据库会话 + 获取数据库会话,用于WEB请求 :return: Session """ db = None @@ -39,6 +36,105 @@ def get_db(): db.close() +def get_args_db(args: tuple, kwargs: dict): + """ + 从参数中获取数据库Session对象 + """ + db = None + if args: + for arg in args: + if isinstance(arg, Session): + db = arg + break + if kwargs: + for key, value in kwargs.items(): + if isinstance(value, Session): + db = value + break + return db + + +def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]: + """ + 更新参数中的数据库Session对象,关键字传参时更新db的值,否则更新第1或第2个参数 + """ + if kwargs: + kwargs['db'] = db + elif args: + if args[0] is None: + args = (db, *args[1:]) + else: + args = (args[0], db, *args[2:]) + return args, kwargs + + +def db_update(func): + """ + 数据库更新类操作装饰器,第一个参数必须是数据库会话或存在db参数 + """ + + def wrapper(*args, **kwargs): + # 是否关闭数据库会话 + _close_db = False + # 从参数中获取数据库会话 + db = get_args_db(args, kwargs) + if not db: + # 如果没有获取到数据库会话,创建一个 + db = ScopedSession() + # 标记需要关闭数据库会话 + _close_db = True + # 更新参数中的数据库会话 + args, kwargs = update_args_db(args, kwargs, db) + try: + # 执行函数 + result = func(*args, **kwargs) + # 提交事务 + db.commit() + except Exception as err: + # 回滚事务 + db.rollback() + raise err + finally: + # 关闭数据库会话 + if _close_db: + db.close() + return result + + return wrapper + + +def db_query(func): + """ + 数据库查询操作装饰器,第一个参数必须是数据库会话或存在db参数 + 注意:db.query列表数据时,需要转换为list返回 + """ + + def wrapper(*args, **kwargs): + # 是否关闭数据库会话 + _close_db = False + # 从参数中获取数据库会话 + db = get_args_db(args, kwargs) + if not db: + # 如果没有获取到数据库会话,创建一个 + db = ScopedSession() + # 标记需要关闭数据库会话 + _close_db = True + # 更新参数中的数据库会话 + args, kwargs = update_args_db(args, kwargs, db) + try: + # 执行函数 + result = func(*args, **kwargs) + except Exception as err: + raise err + finally: + # 关闭数据库会话 + if _close_db: + db.close() + return result + + return wrapper + + class DbOper: """ 数据库操作基类 @@ -46,8 +142,4 @@ class DbOper: _db: Session = None def __init__(self, db: Session = None): - if db: - self._db = db - else: - with DBLock: - self._db = ScopedSession() + self._db = db diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index d06ce301..e5e60eec 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -2,31 +2,7 @@ from typing import Any, Self, List from sqlalchemy.orm import as_declarative, declared_attr, Session -from app.db import DBLock - - -def db_persist(func): - """ - 数据库操作装饰器,获取第一个输入参数db,执行数据库操作后提交 - """ - - def wrapper(*args, **kwargs): - with DBLock: - db: Session = kwargs.get("db") - if not db: - for arg in args: - if isinstance(arg, Session): - db = arg - break - try: - result = func(*args, **kwargs) - db.commit() - except Exception as err: - db.rollback() - raise err - return result - - return wrapper +from app.db import db_update, db_query @as_declarative() @@ -34,34 +10,38 @@ class Base: id: Any __name__: str - @db_persist + @db_update def create(self, db: Session) -> Self: db.add(self) + db.refresh(self) return self @classmethod + @db_query def get(cls, db: Session, rid: int) -> Self: return db.query(cls).filter(cls.id == rid).first() - @db_persist + @db_update def update(self, db: Session, payload: dict): payload = {k: v for k, v in payload.items() if v is not None} for key, value in payload.items(): setattr(self, key, value) @classmethod - @db_persist + @db_update def delete(cls, db: Session, rid): db.query(cls).filter(cls.id == rid).delete() @classmethod - @db_persist + @db_update def truncate(cls, db: Session): db.query(cls).delete() @classmethod + @db_query def list(cls, db: Session) -> List[Self]: - return db.query(cls).all() + result = db.query(cls).all() + return list(result) def to_dict(self): return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} diff --git a/app/db/models/downloadhistory.py b/app/db/models/downloadhistory.py index e249fb25..58358c03 100644 --- a/app/db/models/downloadhistory.py +++ b/app/db/models/downloadhistory.py @@ -1,7 +1,8 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class DownloadHistory(Base): @@ -45,69 +46,80 @@ class DownloadHistory(Base): note = Column(String) @staticmethod + @db_query def get_by_hash(db: Session, download_hash: str): return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).first() @staticmethod + @db_query def list_by_page(db: Session, page: int = 1, count: int = 30): - return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all() + result = db.query(DownloadHistory).offset((page - 1) * count).limit(count).all() + return list(result) @staticmethod + @db_query def get_by_path(db: Session, path: str): return db.query(DownloadHistory).filter(DownloadHistory.path == path).first() @staticmethod + @db_query def get_last_by(db: Session, mtype: str = None, title: str = None, year: int = None, season: str = None, episode: str = None, tmdbid: int = None): """ 据tmdbid、season、season_episode查询转移记录 """ + result = None if tmdbid and not season and not episode: - return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).order_by( DownloadHistory.id.desc()).all() if tmdbid and season and not episode: - return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, - DownloadHistory.seasons == season).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, + DownloadHistory.seasons == season).order_by( DownloadHistory.id.desc()).all() if tmdbid and season and episode: - return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, - DownloadHistory.seasons == season, - DownloadHistory.episodes == episode).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, + DownloadHistory.seasons == season, + DownloadHistory.episodes == episode).order_by( DownloadHistory.id.desc()).all() # 电视剧所有季集|电影 if not season and not episode: - return db.query(DownloadHistory).filter(DownloadHistory.type == mtype, - DownloadHistory.title == title, - DownloadHistory.year == year).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.type == mtype, + DownloadHistory.title == title, + DownloadHistory.year == year).order_by( DownloadHistory.id.desc()).all() # 电视剧某季 if season and not episode: - return db.query(DownloadHistory).filter(DownloadHistory.type == mtype, - DownloadHistory.title == title, - DownloadHistory.year == year, - DownloadHistory.seasons == season).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.type == mtype, + DownloadHistory.title == title, + DownloadHistory.year == year, + DownloadHistory.seasons == season).order_by( DownloadHistory.id.desc()).all() # 电视剧某季某集 if season and episode: - return db.query(DownloadHistory).filter(DownloadHistory.type == mtype, - DownloadHistory.title == title, - DownloadHistory.year == year, - DownloadHistory.seasons == season, - DownloadHistory.episodes == episode).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.type == mtype, + DownloadHistory.title == title, + DownloadHistory.year == year, + DownloadHistory.seasons == season, + DownloadHistory.episodes == episode).order_by( DownloadHistory.id.desc()).all() + if result: + return list(result) + @staticmethod + @db_query def list_by_user_date(db: Session, date: str, userid: str = None): """ 查询某用户某时间之后的下载历史 """ if userid: - return db.query(DownloadHistory).filter(DownloadHistory.date < date, - DownloadHistory.userid == userid).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.date < date, + DownloadHistory.userid == userid).order_by( DownloadHistory.id.desc()).all() else: - return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by( DownloadHistory.id.desc()).all() + return list(result) class DownloadFiles(Base): @@ -131,24 +143,30 @@ class DownloadFiles(Base): state = Column(Integer, nullable=False, default=1) @staticmethod + @db_query def get_by_hash(db: Session, download_hash: str, state: int = None): if state: - return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash, - DownloadFiles.state == state).all() + result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash, + DownloadFiles.state == state).all() else: - return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all() + result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all() + + return list(result) @staticmethod + @db_query def get_by_fullpath(db: Session, fullpath: str): return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by( DownloadFiles.id.desc()).first() @staticmethod + @db_query def get_by_savepath(db: Session, savepath: str): - return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all() + result = db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all() + return list(result) @staticmethod - @db_persist + @db_update def delete_by_fullpath(db: Session, fullpath: str): db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath, DownloadFiles.state == 1).update( diff --git a/app/db/models/mediaserver.py b/app/db/models/mediaserver.py index ee0e06e1..82164daa 100644 --- a/app/db/models/mediaserver.py +++ b/app/db/models/mediaserver.py @@ -3,7 +3,8 @@ from datetime import datetime from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class MediaServerItem(Base): @@ -41,20 +42,23 @@ class MediaServerItem(Base): lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) @staticmethod + @db_query def get_by_itemid(db: Session, item_id: str): return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first() @staticmethod - @db_persist + @db_update def empty(db: Session, server: str): db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() @staticmethod + @db_query def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): return db.query(MediaServerItem).filter(MediaServerItem.tmdbid == tmdbid, MediaServerItem.item_type == mtype).first() @staticmethod + @db_query def exists_by_title(db: Session, title: str, mtype: str, year: str): return db.query(MediaServerItem).filter(MediaServerItem.title == title, MediaServerItem.item_type == mtype, diff --git a/app/db/models/plugin.py b/app/db/models/plugin.py index d1936a84..f6346728 100644 --- a/app/db/models/plugin.py +++ b/app/db/models/plugin.py @@ -1,7 +1,8 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class PluginData(Base): @@ -14,18 +15,23 @@ class PluginData(Base): value = Column(String) @staticmethod + @db_query def get_plugin_data(db: Session, plugin_id: str): - return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() + result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() + return list(result) @staticmethod + @db_query def get_plugin_data_by_key(db: Session, plugin_id: str, key: str): return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first() @staticmethod - @db_persist + @db_update def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() @staticmethod + @db_query def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): - return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() + result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() + return list(result) diff --git a/app/db/models/site.py b/app/db/models/site.py index ced51120..e79fe413 100644 --- a/app/db/models/site.py +++ b/app/db/models/site.py @@ -3,7 +3,8 @@ from datetime import datetime from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class Site(Base): @@ -47,18 +48,23 @@ class Site(Base): lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) @staticmethod + @db_query def get_by_domain(db: Session, domain: str): return db.query(Site).filter(Site.domain == domain).first() @staticmethod + @db_query def get_actives(db: Session): - return db.query(Site).filter(Site.is_active == 1).all() + result = db.query(Site).filter(Site.is_active == 1).all() + return list(result) @staticmethod + @db_query def list_order_by_pri(db: Session): - return db.query(Site).order_by(Site.pri).all() + result = db.query(Site).order_by(Site.pri).all() + return list(result) @staticmethod - @db_persist + @db_update def reset(db: Session): db.query(Site).delete() diff --git a/app/db/models/siteicon.py b/app/db/models/siteicon.py index ef4ca692..787b86c4 100644 --- a/app/db/models/siteicon.py +++ b/app/db/models/siteicon.py @@ -1,6 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session +from app.db import db_query from app.db.models import Base @@ -19,5 +20,6 @@ class SiteIcon(Base): base64 = Column(String) @staticmethod + @db_query def get_by_domain(db: Session, domain: str): return db.query(SiteIcon).filter(SiteIcon.domain == domain).first() diff --git a/app/db/models/subscribe.py b/app/db/models/subscribe.py index 92cfdf3d..efec6262 100644 --- a/app/db/models/subscribe.py +++ b/app/db/models/subscribe.py @@ -1,6 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session +from app.db import db_update, db_query from app.db.models import Base @@ -67,6 +68,7 @@ class Subscribe(Base): current_priority = Column(Integer) @staticmethod + @db_query def exists(db: Session, tmdbid: int, season: int = None): if season: return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, @@ -74,30 +76,39 @@ class Subscribe(Base): return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first() @staticmethod + @db_query def get_by_state(db: Session, state: str): - return db.query(Subscribe).filter(Subscribe.state == state).all() + result = db.query(Subscribe).filter(Subscribe.state == state).all() + return list(result) @staticmethod + @db_query def get_by_tmdbid(db: Session, tmdbid: int, season: int = None): if season: - return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, - Subscribe.season == season).all() - return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all() + result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, + Subscribe.season == season).all() + else: + result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all() + return list(result) @staticmethod + @db_query def get_by_title(db: Session, title: str): return db.query(Subscribe).filter(Subscribe.name == title).first() @staticmethod + @db_query def get_by_doubanid(db: Session, doubanid: str): return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first() + @db_update def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int): subscrbies = self.get_by_tmdbid(db, tmdbid, season) for subscrbie in subscrbies: subscrbie.delete(db, subscrbie.id) return True + @db_update def delete_by_doubanid(self, db: Session, doubanid: str): subscribe = self.get_by_doubanid(db, doubanid) if subscribe: diff --git a/app/db/models/systemconfig.py b/app/db/models/systemconfig.py index 9fe299bb..5a304c01 100644 --- a/app/db/models/systemconfig.py +++ b/app/db/models/systemconfig.py @@ -1,6 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session +from app.db import db_update, db_query from app.db.models import Base @@ -15,9 +16,11 @@ class SystemConfig(Base): value = Column(String, nullable=True) @staticmethod + @db_query def get_by_key(db: Session, key: str): return db.query(SystemConfig).filter(SystemConfig.key == key).first() + @db_update def delete_by_key(self, db: Session, key: str): systemconfig = self.get_by_key(db, key) if systemconfig: diff --git a/app/db/models/transferhistory.py b/app/db/models/transferhistory.py index 31c55f5c..9ba25c78 100644 --- a/app/db/models/transferhistory.py +++ b/app/db/models/transferhistory.py @@ -3,7 +3,8 @@ import time from sqlalchemy import Column, Integer, String, Sequence, Boolean, func from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class TransferHistory(Base): @@ -47,29 +48,38 @@ class TransferHistory(Base): files = Column(String) @staticmethod + @db_query def list_by_title(db: Session, title: str, page: int = 1, count: int = 30): - return db.query(TransferHistory).filter(TransferHistory.title.like(f'%{title}%')).order_by( + result = db.query(TransferHistory).filter(TransferHistory.title.like(f'%{title}%')).order_by( TransferHistory.date.desc()).offset((page - 1) * count).limit( count).all() + return list(result) @staticmethod + @db_query def list_by_page(db: Session, page: int = 1, count: int = 30): - return db.query(TransferHistory).order_by(TransferHistory.date.desc()).offset((page - 1) * count).limit( + result = db.query(TransferHistory).order_by(TransferHistory.date.desc()).offset((page - 1) * count).limit( count).all() + return list(result) @staticmethod + @db_query def get_by_hash(db: Session, download_hash: str): return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).first() @staticmethod + @db_query def get_by_src(db: Session, src: str): return db.query(TransferHistory).filter(TransferHistory.src == src).first() @staticmethod + @db_query def list_by_hash(db: Session, download_hash: str): - return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all() + result = db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all() + return list(result) @staticmethod + @db_query def statistic(db: Session, days: int = 7): """ 统计最近days天的下载历史数量,按日期分组返回每日数量 @@ -78,74 +88,82 @@ class TransferHistory(Base): TransferHistory.id.label('id')).filter( TransferHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() - 86400 * days))).subquery() - return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all() + result = db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all() + return list(result) @staticmethod + @db_query def count(db: Session): return db.query(func.count(TransferHistory.id)).first()[0] @staticmethod + @db_query def count_by_title(db: Session, title: str): return db.query(func.count(TransferHistory.id)).filter(TransferHistory.title.like(f'%{title}%')).first()[0] @staticmethod + @db_query def list_by(db: Session, mtype: str = None, title: str = None, year: str = None, season: str = None, episode: str = None, tmdbid: int = None, dest: str = None): """ 据tmdbid、season、season_episode查询转移记录 tmdbid + mtype 或 title + year 必输 """ + result = None # TMDBID + 类型 if tmdbid and mtype: # 电视剧某季某集 if season and episode: - return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, - TransferHistory.type == mtype, - TransferHistory.seasons == season, - TransferHistory.episodes == episode, - TransferHistory.dest == dest).all() + result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, + TransferHistory.type == mtype, + TransferHistory.seasons == season, + TransferHistory.episodes == episode, + TransferHistory.dest == dest).all() # 电视剧某季 elif season: - return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, - TransferHistory.type == mtype, - TransferHistory.seasons == season).all() + result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, + TransferHistory.type == mtype, + TransferHistory.seasons == season).all() else: if dest: # 电影 - return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, - TransferHistory.type == mtype, - TransferHistory.dest == dest).all() + result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, + TransferHistory.type == mtype, + TransferHistory.dest == dest).all() else: # 电视剧所有季集 - return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, - TransferHistory.type == mtype).all() + result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, + TransferHistory.type == mtype).all() # 标题 + 年份 elif title and year: # 电视剧某季某集 if season and episode: - return db.query(TransferHistory).filter(TransferHistory.title == title, - TransferHistory.year == year, - TransferHistory.seasons == season, - TransferHistory.episodes == episode, - TransferHistory.dest == dest).all() + result = db.query(TransferHistory).filter(TransferHistory.title == title, + TransferHistory.year == year, + TransferHistory.seasons == season, + TransferHistory.episodes == episode, + TransferHistory.dest == dest).all() # 电视剧某季 elif season: - return db.query(TransferHistory).filter(TransferHistory.title == title, - TransferHistory.year == year, - TransferHistory.seasons == season).all() + result = db.query(TransferHistory).filter(TransferHistory.title == title, + TransferHistory.year == year, + TransferHistory.seasons == season).all() else: if dest: # 电影 - return db.query(TransferHistory).filter(TransferHistory.title == title, - TransferHistory.year == year, - TransferHistory.dest == dest).all() + result = db.query(TransferHistory).filter(TransferHistory.title == title, + TransferHistory.year == year, + TransferHistory.dest == dest).all() else: # 电视剧所有季集 - return db.query(TransferHistory).filter(TransferHistory.title == title, - TransferHistory.year == year).all() + result = db.query(TransferHistory).filter(TransferHistory.title == title, + TransferHistory.year == year).all() + if result: + return list(result) return [] @staticmethod + @db_query def get_by_type_tmdbid(db: Session, mtype: str = None, tmdbid: int = None): """ 据tmdbid、type查询转移记录 @@ -154,7 +172,7 @@ class TransferHistory(Base): TransferHistory.type == mtype).first() @staticmethod - @db_persist + @db_update def update_download_hash(db: Session, historyid: int = None, download_hash: str = None): db.query(TransferHistory).filter(TransferHistory.id == historyid).update( { diff --git a/app/db/models/user.py b/app/db/models/user.py index 94676f5b..b058c48a 100644 --- a/app/db/models/user.py +++ b/app/db/models/user.py @@ -2,6 +2,7 @@ from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy.orm import Session from app.core.security import verify_password +from app.db import db_update, db_query from app.db.models import Base @@ -25,6 +26,7 @@ class User(Base): avatar = Column(String) @staticmethod + @db_query def authenticate(db: Session, name: str, password: str): user = db.query(User).filter(User.name == name).first() if not user: @@ -34,9 +36,11 @@ class User(Base): return user @staticmethod + @db_query def get_by_name(db: Session, name: str): return db.query(User).filter(User.name == name).first() + @db_update def delete_by_name(self, db: Session, name: str): user = self.get_by_name(db, name) if user: From e5718a50b2f2ad0d094493b68360f956bdc2de37 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 19 Oct 2023 17:15:46 +0800 Subject: [PATCH 2/6] fix bug --- app/db/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index e5e60eec..c0e0392f 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -13,7 +13,6 @@ class Base: @db_update def create(self, db: Session) -> Self: db.add(self) - db.refresh(self) return self @classmethod From b33e77702858d33001a42e279fdeea0aab0095ca Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 19 Oct 2023 17:39:15 +0800 Subject: [PATCH 3/6] fix bug --- app/db/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/app/db/__init__.py b/app/db/__init__.py index 3f9a436c..19649177 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional, Generator from sqlalchemy import create_engine, QueuePool from sqlalchemy.orm import sessionmaker, Session, scoped_session @@ -22,7 +22,7 @@ SessionFactory = sessionmaker(bind=Engine) ScopedSession = scoped_session(SessionFactory) -def get_db(): +def get_db() -> Generator: """ 获取数据库会话,用于WEB请求 :return: Session @@ -36,7 +36,7 @@ def get_db(): db.close() -def get_args_db(args: tuple, kwargs: dict): +def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]: """ 从参数中获取数据库Session对象 """ @@ -58,7 +58,7 @@ def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict] """ 更新参数中的数据库Session对象,关键字传参时更新db的值,否则更新第1或第2个参数 """ - if kwargs: + if kwargs and 'db' in kwargs: kwargs['db'] = db elif args: if args[0] is None: From 4739d43c45699b112a02d4f6e361e1928d7a5973 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 19 Oct 2023 17:55:40 +0800 Subject: [PATCH 4/6] v1.3.5 --- version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.py b/version.py index 29dff19e..c130a7d2 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -APP_VERSION = 'v1.3.4' +APP_VERSION = 'v1.3.5' From c932d2b7f0ee0ab8883de0c54e8680189d89165c Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 19 Oct 2023 18:08:56 +0800 Subject: [PATCH 5/6] fix bug --- app/api/endpoints/subscribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/api/endpoints/subscribe.py b/app/api/endpoints/subscribe.py index 63889972..afba458f 100644 --- a/app/api/endpoints/subscribe.py +++ b/app/api/endpoints/subscribe.py @@ -195,7 +195,7 @@ def read_subscribe( 根据订阅编号查询订阅信息 """ subscribe = Subscribe.get(db, subscribe_id) - if subscribe.sites: + if subscribe and subscribe.sites: subscribe.sites = json.loads(subscribe.sites) return subscribe From 7d9a3d39b32033747a222eb2485edc8e7a382c8d Mon Sep 17 00:00:00 2001 From: thsrite Date: Thu, 19 Oct 2023 18:09:49 +0800 Subject: [PATCH 6/6] =?UTF-8?q?fix=20=E6=9B=B4=E6=96=B0=E9=80=9A=E7=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../moviepilotupdatenotify/__init__.py | 176 +++++++++--------- 1 file changed, 90 insertions(+), 86 deletions(-) diff --git a/app/plugins/moviepilotupdatenotify/__init__.py b/app/plugins/moviepilotupdatenotify/__init__.py index 8336e0c1..b088c77f 100644 --- a/app/plugins/moviepilotupdatenotify/__init__.py +++ b/app/plugins/moviepilotupdatenotify/__init__.py @@ -1,3 +1,5 @@ +import datetime + from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.triggers.cron import CronTrigger @@ -81,12 +83,14 @@ class MoviePilotUpdateNotify(_PluginBase): # 本地版本 local_version = SystemChain().get_local_version() - if release_version == local_version: + if local_version and release_version <= local_version: logger.info(f"当前版本:{local_version} 远程版本:{release_version} 停止运行") return # 推送更新消息 if self._notify: + # 将时间字符串转为datetime对象 + update_time = datetime.datetime.strptime(update_time, "%Y-%m-%dT%H:%M:%SZ").strftime("%Y-%m-%d %H:%M:%S") self.post_message( mtype=NotificationType.SiteMessage, title="【MoviePilot更新通知】", @@ -131,91 +135,91 @@ class MoviePilotUpdateNotify(_PluginBase): 拼装插件配置页面,需要返回两块数据:1、页面配置;2、数据结构 """ return [ - { - 'component': 'VForm', - 'content': [ - { - 'component': 'VRow', - 'content': [ - { - 'component': 'VCol', - 'props': { - 'cols': 12, - 'md': 4 - }, - 'content': [ - { - 'component': 'VSwitch', - 'props': { - 'model': 'enabled', - 'label': '启用插件', - } - } - ] - }, - { - 'component': 'VCol', - 'props': { - 'cols': 12, - 'md': 4 - }, - 'content': [ - { - 'component': 'VSwitch', - 'props': { - 'model': 'update', - 'label': '自动更新', - } - } - ] - }, - { - 'component': 'VCol', - 'props': { - 'cols': 12, - 'md': 4 - }, - 'content': [ - { - 'component': 'VSwitch', - 'props': { - 'model': 'notify', - 'label': '发送通知', - } - } - ] - } - ] - }, - { - 'component': 'VRow', - 'content': [ - { - 'component': 'VCol', - 'props': { - 'cols': 12, - }, - 'content': [ - { - 'component': 'VTextField', - 'props': { - 'model': 'cron', - 'label': '检查周期', - 'placeholder': '5位cron表达式' - } - } - ] - }, - ] - } - ] - } - ], { - "enabled": False, - "update": False, - "notify": False, - "cron": "0 9 * * *" - } + { + 'component': 'VForm', + 'content': [ + { + 'component': 'VRow', + 'content': [ + { + 'component': 'VCol', + 'props': { + 'cols': 12, + 'md': 4 + }, + 'content': [ + { + 'component': 'VSwitch', + 'props': { + 'model': 'enabled', + 'label': '启用插件', + } + } + ] + }, + { + 'component': 'VCol', + 'props': { + 'cols': 12, + 'md': 4 + }, + 'content': [ + { + 'component': 'VSwitch', + 'props': { + 'model': 'update', + 'label': '自动更新', + } + } + ] + }, + { + 'component': 'VCol', + 'props': { + 'cols': 12, + 'md': 4 + }, + 'content': [ + { + 'component': 'VSwitch', + 'props': { + 'model': 'notify', + 'label': '发送通知', + } + } + ] + } + ] + }, + { + 'component': 'VRow', + 'content': [ + { + 'component': 'VCol', + 'props': { + 'cols': 12, + }, + 'content': [ + { + 'component': 'VTextField', + 'props': { + 'model': 'cron', + 'label': '检查周期', + 'placeholder': '5位cron表达式' + } + } + ] + }, + ] + } + ] + } + ], { + "enabled": False, + "update": False, + "notify": False, + "cron": "0 9 * * *" + } def get_page(self) -> List[dict]: pass