From a911bab7b0f9eb5ce94d7ffe6aced803bc206b09 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 19 Oct 2023 16:58:38 +0800 Subject: [PATCH] fix db session --- app/db/__init__.py | 112 ++++++++++++++++++++++++++++--- app/db/models/__init__.py | 40 +++-------- app/db/models/downloadhistory.py | 74 ++++++++++++-------- app/db/models/mediaserver.py | 8 ++- app/db/models/plugin.py | 14 ++-- app/db/models/site.py | 14 ++-- app/db/models/siteicon.py | 2 + app/db/models/subscribe.py | 19 ++++-- app/db/models/systemconfig.py | 3 + app/db/models/transferhistory.py | 82 +++++++++++++--------- app/db/models/user.py | 4 ++ 11 files changed, 258 insertions(+), 114 deletions(-) diff --git a/app/db/__init__.py b/app/db/__init__.py index dda252a2..3f9a436c 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,13 +1,10 @@ -import threading +from typing import Tuple from sqlalchemy import create_engine, QueuePool from sqlalchemy.orm import sessionmaker, Session, scoped_session from app.core.config import settings -# 数据库锁 -DBLock = threading.Lock() - # 数据库引擎 Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db", pool_pre_ping=True, @@ -27,7 +24,7 @@ ScopedSession = scoped_session(SessionFactory) def get_db(): """ - 获取数据库会话 + 获取数据库会话,用于WEB请求 :return: Session """ db = None @@ -39,6 +36,105 @@ def get_db(): db.close() +def get_args_db(args: tuple, kwargs: dict): + """ + 从参数中获取数据库Session对象 + """ + db = None + if args: + for arg in args: + if isinstance(arg, Session): + db = arg + break + if kwargs: + for key, value in kwargs.items(): + if isinstance(value, Session): + db = value + break + return db + + +def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]: + """ + 更新参数中的数据库Session对象,关键字传参时更新db的值,否则更新第1或第2个参数 + """ + if kwargs: + kwargs['db'] = db + elif args: + if args[0] is None: + args = (db, *args[1:]) + else: + args = (args[0], db, *args[2:]) + return args, kwargs + + +def db_update(func): + """ + 数据库更新类操作装饰器,第一个参数必须是数据库会话或存在db参数 + """ + + def wrapper(*args, **kwargs): + # 是否关闭数据库会话 + _close_db = False + # 从参数中获取数据库会话 + db = get_args_db(args, kwargs) + if not db: + # 如果没有获取到数据库会话,创建一个 + db = ScopedSession() + # 标记需要关闭数据库会话 + _close_db = True + # 更新参数中的数据库会话 + args, kwargs = update_args_db(args, kwargs, db) + try: + # 执行函数 + result = func(*args, **kwargs) + # 提交事务 + db.commit() + except Exception as err: + # 回滚事务 + db.rollback() + raise err + finally: + # 关闭数据库会话 + if _close_db: + db.close() + return result + + return wrapper + + +def db_query(func): + """ + 数据库查询操作装饰器,第一个参数必须是数据库会话或存在db参数 + 注意:db.query列表数据时,需要转换为list返回 + """ + + def wrapper(*args, **kwargs): + # 是否关闭数据库会话 + _close_db = False + # 从参数中获取数据库会话 + db = get_args_db(args, kwargs) + if not db: + # 如果没有获取到数据库会话,创建一个 + db = ScopedSession() + # 标记需要关闭数据库会话 + _close_db = True + # 更新参数中的数据库会话 + args, kwargs = update_args_db(args, kwargs, db) + try: + # 执行函数 + result = func(*args, **kwargs) + except Exception as err: + raise err + finally: + # 关闭数据库会话 + if _close_db: + db.close() + return result + + return wrapper + + class DbOper: """ 数据库操作基类 @@ -46,8 +142,4 @@ class DbOper: _db: Session = None def __init__(self, db: Session = None): - if db: - self._db = db - else: - with DBLock: - self._db = ScopedSession() + self._db = db diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index d06ce301..e5e60eec 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -2,31 +2,7 @@ from typing import Any, Self, List from sqlalchemy.orm import as_declarative, declared_attr, Session -from app.db import DBLock - - -def db_persist(func): - """ - 数据库操作装饰器,获取第一个输入参数db,执行数据库操作后提交 - """ - - def wrapper(*args, **kwargs): - with DBLock: - db: Session = kwargs.get("db") - if not db: - for arg in args: - if isinstance(arg, Session): - db = arg - break - try: - result = func(*args, **kwargs) - db.commit() - except Exception as err: - db.rollback() - raise err - return result - - return wrapper +from app.db import db_update, db_query @as_declarative() @@ -34,34 +10,38 @@ class Base: id: Any __name__: str - @db_persist + @db_update def create(self, db: Session) -> Self: db.add(self) + db.refresh(self) return self @classmethod + @db_query def get(cls, db: Session, rid: int) -> Self: return db.query(cls).filter(cls.id == rid).first() - @db_persist + @db_update def update(self, db: Session, payload: dict): payload = {k: v for k, v in payload.items() if v is not None} for key, value in payload.items(): setattr(self, key, value) @classmethod - @db_persist + @db_update def delete(cls, db: Session, rid): db.query(cls).filter(cls.id == rid).delete() @classmethod - @db_persist + @db_update def truncate(cls, db: Session): db.query(cls).delete() @classmethod + @db_query def list(cls, db: Session) -> List[Self]: - return db.query(cls).all() + result = db.query(cls).all() + return list(result) def to_dict(self): return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} diff --git a/app/db/models/downloadhistory.py b/app/db/models/downloadhistory.py index e249fb25..58358c03 100644 --- a/app/db/models/downloadhistory.py +++ b/app/db/models/downloadhistory.py @@ -1,7 +1,8 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class DownloadHistory(Base): @@ -45,69 +46,80 @@ class DownloadHistory(Base): note = Column(String) @staticmethod + @db_query def get_by_hash(db: Session, download_hash: str): return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).first() @staticmethod + @db_query def list_by_page(db: Session, page: int = 1, count: int = 30): - return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all() + result = db.query(DownloadHistory).offset((page - 1) * count).limit(count).all() + return list(result) @staticmethod + @db_query def get_by_path(db: Session, path: str): return db.query(DownloadHistory).filter(DownloadHistory.path == path).first() @staticmethod + @db_query def get_last_by(db: Session, mtype: str = None, title: str = None, year: int = None, season: str = None, episode: str = None, tmdbid: int = None): """ 据tmdbid、season、season_episode查询转移记录 """ + result = None if tmdbid and not season and not episode: - return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).order_by( DownloadHistory.id.desc()).all() if tmdbid and season and not episode: - return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, - DownloadHistory.seasons == season).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, + DownloadHistory.seasons == season).order_by( DownloadHistory.id.desc()).all() if tmdbid and season and episode: - return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, - DownloadHistory.seasons == season, - DownloadHistory.episodes == episode).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, + DownloadHistory.seasons == season, + DownloadHistory.episodes == episode).order_by( DownloadHistory.id.desc()).all() # 电视剧所有季集|电影 if not season and not episode: - return db.query(DownloadHistory).filter(DownloadHistory.type == mtype, - DownloadHistory.title == title, - DownloadHistory.year == year).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.type == mtype, + DownloadHistory.title == title, + DownloadHistory.year == year).order_by( DownloadHistory.id.desc()).all() # 电视剧某季 if season and not episode: - return db.query(DownloadHistory).filter(DownloadHistory.type == mtype, - DownloadHistory.title == title, - DownloadHistory.year == year, - DownloadHistory.seasons == season).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.type == mtype, + DownloadHistory.title == title, + DownloadHistory.year == year, + DownloadHistory.seasons == season).order_by( DownloadHistory.id.desc()).all() # 电视剧某季某集 if season and episode: - return db.query(DownloadHistory).filter(DownloadHistory.type == mtype, - DownloadHistory.title == title, - DownloadHistory.year == year, - DownloadHistory.seasons == season, - DownloadHistory.episodes == episode).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.type == mtype, + DownloadHistory.title == title, + DownloadHistory.year == year, + DownloadHistory.seasons == season, + DownloadHistory.episodes == episode).order_by( DownloadHistory.id.desc()).all() + if result: + return list(result) + @staticmethod + @db_query def list_by_user_date(db: Session, date: str, userid: str = None): """ 查询某用户某时间之后的下载历史 """ if userid: - return db.query(DownloadHistory).filter(DownloadHistory.date < date, - DownloadHistory.userid == userid).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.date < date, + DownloadHistory.userid == userid).order_by( DownloadHistory.id.desc()).all() else: - return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by( + result = db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by( DownloadHistory.id.desc()).all() + return list(result) class DownloadFiles(Base): @@ -131,24 +143,30 @@ class DownloadFiles(Base): state = Column(Integer, nullable=False, default=1) @staticmethod + @db_query def get_by_hash(db: Session, download_hash: str, state: int = None): if state: - return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash, - DownloadFiles.state == state).all() + result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash, + DownloadFiles.state == state).all() else: - return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all() + result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all() + + return list(result) @staticmethod + @db_query def get_by_fullpath(db: Session, fullpath: str): return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by( DownloadFiles.id.desc()).first() @staticmethod + @db_query def get_by_savepath(db: Session, savepath: str): - return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all() + result = db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all() + return list(result) @staticmethod - @db_persist + @db_update def delete_by_fullpath(db: Session, fullpath: str): db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath, DownloadFiles.state == 1).update( diff --git a/app/db/models/mediaserver.py b/app/db/models/mediaserver.py index ee0e06e1..82164daa 100644 --- a/app/db/models/mediaserver.py +++ b/app/db/models/mediaserver.py @@ -3,7 +3,8 @@ from datetime import datetime from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class MediaServerItem(Base): @@ -41,20 +42,23 @@ class MediaServerItem(Base): lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) @staticmethod + @db_query def get_by_itemid(db: Session, item_id: str): return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first() @staticmethod - @db_persist + @db_update def empty(db: Session, server: str): db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() @staticmethod + @db_query def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): return db.query(MediaServerItem).filter(MediaServerItem.tmdbid == tmdbid, MediaServerItem.item_type == mtype).first() @staticmethod + @db_query def exists_by_title(db: Session, title: str, mtype: str, year: str): return db.query(MediaServerItem).filter(MediaServerItem.title == title, MediaServerItem.item_type == mtype, diff --git a/app/db/models/plugin.py b/app/db/models/plugin.py index d1936a84..f6346728 100644 --- a/app/db/models/plugin.py +++ b/app/db/models/plugin.py @@ -1,7 +1,8 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class PluginData(Base): @@ -14,18 +15,23 @@ class PluginData(Base): value = Column(String) @staticmethod + @db_query def get_plugin_data(db: Session, plugin_id: str): - return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() + result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() + return list(result) @staticmethod + @db_query def get_plugin_data_by_key(db: Session, plugin_id: str, key: str): return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first() @staticmethod - @db_persist + @db_update def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() @staticmethod + @db_query def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): - return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() + result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() + return list(result) diff --git a/app/db/models/site.py b/app/db/models/site.py index ced51120..e79fe413 100644 --- a/app/db/models/site.py +++ b/app/db/models/site.py @@ -3,7 +3,8 @@ from datetime import datetime from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class Site(Base): @@ -47,18 +48,23 @@ class Site(Base): lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) @staticmethod + @db_query def get_by_domain(db: Session, domain: str): return db.query(Site).filter(Site.domain == domain).first() @staticmethod + @db_query def get_actives(db: Session): - return db.query(Site).filter(Site.is_active == 1).all() + result = db.query(Site).filter(Site.is_active == 1).all() + return list(result) @staticmethod + @db_query def list_order_by_pri(db: Session): - return db.query(Site).order_by(Site.pri).all() + result = db.query(Site).order_by(Site.pri).all() + return list(result) @staticmethod - @db_persist + @db_update def reset(db: Session): db.query(Site).delete() diff --git a/app/db/models/siteicon.py b/app/db/models/siteicon.py index ef4ca692..787b86c4 100644 --- a/app/db/models/siteicon.py +++ b/app/db/models/siteicon.py @@ -1,6 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session +from app.db import db_query from app.db.models import Base @@ -19,5 +20,6 @@ class SiteIcon(Base): base64 = Column(String) @staticmethod + @db_query def get_by_domain(db: Session, domain: str): return db.query(SiteIcon).filter(SiteIcon.domain == domain).first() diff --git a/app/db/models/subscribe.py b/app/db/models/subscribe.py index 92cfdf3d..efec6262 100644 --- a/app/db/models/subscribe.py +++ b/app/db/models/subscribe.py @@ -1,6 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session +from app.db import db_update, db_query from app.db.models import Base @@ -67,6 +68,7 @@ class Subscribe(Base): current_priority = Column(Integer) @staticmethod + @db_query def exists(db: Session, tmdbid: int, season: int = None): if season: return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, @@ -74,30 +76,39 @@ class Subscribe(Base): return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first() @staticmethod + @db_query def get_by_state(db: Session, state: str): - return db.query(Subscribe).filter(Subscribe.state == state).all() + result = db.query(Subscribe).filter(Subscribe.state == state).all() + return list(result) @staticmethod + @db_query def get_by_tmdbid(db: Session, tmdbid: int, season: int = None): if season: - return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, - Subscribe.season == season).all() - return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all() + result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, + Subscribe.season == season).all() + else: + result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all() + return list(result) @staticmethod + @db_query def get_by_title(db: Session, title: str): return db.query(Subscribe).filter(Subscribe.name == title).first() @staticmethod + @db_query def get_by_doubanid(db: Session, doubanid: str): return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first() + @db_update def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int): subscrbies = self.get_by_tmdbid(db, tmdbid, season) for subscrbie in subscrbies: subscrbie.delete(db, subscrbie.id) return True + @db_update def delete_by_doubanid(self, db: Session, doubanid: str): subscribe = self.get_by_doubanid(db, doubanid) if subscribe: diff --git a/app/db/models/systemconfig.py b/app/db/models/systemconfig.py index 9fe299bb..5a304c01 100644 --- a/app/db/models/systemconfig.py +++ b/app/db/models/systemconfig.py @@ -1,6 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session +from app.db import db_update, db_query from app.db.models import Base @@ -15,9 +16,11 @@ class SystemConfig(Base): value = Column(String, nullable=True) @staticmethod + @db_query def get_by_key(db: Session, key: str): return db.query(SystemConfig).filter(SystemConfig.key == key).first() + @db_update def delete_by_key(self, db: Session, key: str): systemconfig = self.get_by_key(db, key) if systemconfig: diff --git a/app/db/models/transferhistory.py b/app/db/models/transferhistory.py index 31c55f5c..9ba25c78 100644 --- a/app/db/models/transferhistory.py +++ b/app/db/models/transferhistory.py @@ -3,7 +3,8 @@ import time from sqlalchemy import Column, Integer, String, Sequence, Boolean, func from sqlalchemy.orm import Session -from app.db.models import Base, db_persist +from app.db import db_query +from app.db.models import Base, db_update class TransferHistory(Base): @@ -47,29 +48,38 @@ class TransferHistory(Base): files = Column(String) @staticmethod + @db_query def list_by_title(db: Session, title: str, page: int = 1, count: int = 30): - return db.query(TransferHistory).filter(TransferHistory.title.like(f'%{title}%')).order_by( + result = db.query(TransferHistory).filter(TransferHistory.title.like(f'%{title}%')).order_by( TransferHistory.date.desc()).offset((page - 1) * count).limit( count).all() + return list(result) @staticmethod + @db_query def list_by_page(db: Session, page: int = 1, count: int = 30): - return db.query(TransferHistory).order_by(TransferHistory.date.desc()).offset((page - 1) * count).limit( + result = db.query(TransferHistory).order_by(TransferHistory.date.desc()).offset((page - 1) * count).limit( count).all() + return list(result) @staticmethod + @db_query def get_by_hash(db: Session, download_hash: str): return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).first() @staticmethod + @db_query def get_by_src(db: Session, src: str): return db.query(TransferHistory).filter(TransferHistory.src == src).first() @staticmethod + @db_query def list_by_hash(db: Session, download_hash: str): - return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all() + result = db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all() + return list(result) @staticmethod + @db_query def statistic(db: Session, days: int = 7): """ 统计最近days天的下载历史数量,按日期分组返回每日数量 @@ -78,74 +88,82 @@ class TransferHistory(Base): TransferHistory.id.label('id')).filter( TransferHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() - 86400 * days))).subquery() - return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all() + result = db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all() + return list(result) @staticmethod + @db_query def count(db: Session): return db.query(func.count(TransferHistory.id)).first()[0] @staticmethod + @db_query def count_by_title(db: Session, title: str): return db.query(func.count(TransferHistory.id)).filter(TransferHistory.title.like(f'%{title}%')).first()[0] @staticmethod + @db_query def list_by(db: Session, mtype: str = None, title: str = None, year: str = None, season: str = None, episode: str = None, tmdbid: int = None, dest: str = None): """ 据tmdbid、season、season_episode查询转移记录 tmdbid + mtype 或 title + year 必输 """ + result = None # TMDBID + 类型 if tmdbid and mtype: # 电视剧某季某集 if season and episode: - return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, - TransferHistory.type == mtype, - TransferHistory.seasons == season, - TransferHistory.episodes == episode, - TransferHistory.dest == dest).all() + result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, + TransferHistory.type == mtype, + TransferHistory.seasons == season, + TransferHistory.episodes == episode, + TransferHistory.dest == dest).all() # 电视剧某季 elif season: - return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, - TransferHistory.type == mtype, - TransferHistory.seasons == season).all() + result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, + TransferHistory.type == mtype, + TransferHistory.seasons == season).all() else: if dest: # 电影 - return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, - TransferHistory.type == mtype, - TransferHistory.dest == dest).all() + result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, + TransferHistory.type == mtype, + TransferHistory.dest == dest).all() else: # 电视剧所有季集 - return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, - TransferHistory.type == mtype).all() + result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, + TransferHistory.type == mtype).all() # 标题 + 年份 elif title and year: # 电视剧某季某集 if season and episode: - return db.query(TransferHistory).filter(TransferHistory.title == title, - TransferHistory.year == year, - TransferHistory.seasons == season, - TransferHistory.episodes == episode, - TransferHistory.dest == dest).all() + result = db.query(TransferHistory).filter(TransferHistory.title == title, + TransferHistory.year == year, + TransferHistory.seasons == season, + TransferHistory.episodes == episode, + TransferHistory.dest == dest).all() # 电视剧某季 elif season: - return db.query(TransferHistory).filter(TransferHistory.title == title, - TransferHistory.year == year, - TransferHistory.seasons == season).all() + result = db.query(TransferHistory).filter(TransferHistory.title == title, + TransferHistory.year == year, + TransferHistory.seasons == season).all() else: if dest: # 电影 - return db.query(TransferHistory).filter(TransferHistory.title == title, - TransferHistory.year == year, - TransferHistory.dest == dest).all() + result = db.query(TransferHistory).filter(TransferHistory.title == title, + TransferHistory.year == year, + TransferHistory.dest == dest).all() else: # 电视剧所有季集 - return db.query(TransferHistory).filter(TransferHistory.title == title, - TransferHistory.year == year).all() + result = db.query(TransferHistory).filter(TransferHistory.title == title, + TransferHistory.year == year).all() + if result: + return list(result) return [] @staticmethod + @db_query def get_by_type_tmdbid(db: Session, mtype: str = None, tmdbid: int = None): """ 据tmdbid、type查询转移记录 @@ -154,7 +172,7 @@ class TransferHistory(Base): TransferHistory.type == mtype).first() @staticmethod - @db_persist + @db_update def update_download_hash(db: Session, historyid: int = None, download_hash: str = None): db.query(TransferHistory).filter(TransferHistory.id == historyid).update( { diff --git a/app/db/models/user.py b/app/db/models/user.py index 94676f5b..b058c48a 100644 --- a/app/db/models/user.py +++ b/app/db/models/user.py @@ -2,6 +2,7 @@ from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy.orm import Session from app.core.security import verify_password +from app.db import db_update, db_query from app.db.models import Base @@ -25,6 +26,7 @@ class User(Base): avatar = Column(String) @staticmethod + @db_query def authenticate(db: Session, name: str, password: str): user = db.query(User).filter(User.name == name).first() if not user: @@ -34,9 +36,11 @@ class User(Base): return user @staticmethod + @db_query def get_by_name(db: Session, name: str): return db.query(User).filter(User.name == name).first() + @db_update def delete_by_name(self, db: Session, name: str): user = self.get_by_name(db, name) if user: