fix db session

This commit is contained in:
jxxghp 2023-10-19 16:58:38 +08:00
parent 21908bdc6f
commit a911bab7b0
11 changed files with 258 additions and 114 deletions

View File

@ -1,13 +1,10 @@
import threading from typing import Tuple
from sqlalchemy import create_engine, QueuePool from sqlalchemy import create_engine, QueuePool
from sqlalchemy.orm import sessionmaker, Session, scoped_session from sqlalchemy.orm import sessionmaker, Session, scoped_session
from app.core.config import settings from app.core.config import settings
# 数据库锁
DBLock = threading.Lock()
# 数据库引擎 # 数据库引擎
Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db", Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db",
pool_pre_ping=True, pool_pre_ping=True,
@ -27,7 +24,7 @@ ScopedSession = scoped_session(SessionFactory)
def get_db(): def get_db():
""" """
获取数据库会话 获取数据库会话用于WEB请求
:return: Session :return: Session
""" """
db = None db = None
@ -39,6 +36,105 @@ def get_db():
db.close() db.close()
def get_args_db(args: tuple, kwargs: dict):
"""
从参数中获取数据库Session对象
"""
db = None
if args:
for arg in args:
if isinstance(arg, Session):
db = arg
break
if kwargs:
for key, value in kwargs.items():
if isinstance(value, Session):
db = value
break
return db
def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
"""
更新参数中的数据库Session对象关键字传参时更新db的值否则更新第1或第2个参数
"""
if kwargs:
kwargs['db'] = db
elif args:
if args[0] is None:
args = (db, *args[1:])
else:
args = (args[0], db, *args[2:])
return args, kwargs
def db_update(func):
"""
数据库更新类操作装饰器第一个参数必须是数据库会话或存在db参数
"""
def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取数据库会话
db = get_args_db(args, kwargs)
if not db:
# 如果没有获取到数据库会话,创建一个
db = ScopedSession()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db)
try:
# 执行函数
result = func(*args, **kwargs)
# 提交事务
db.commit()
except Exception as err:
# 回滚事务
db.rollback()
raise err
finally:
# 关闭数据库会话
if _close_db:
db.close()
return result
return wrapper
def db_query(func):
"""
数据库查询操作装饰器第一个参数必须是数据库会话或存在db参数
注意db.query列表数据时需要转换为list返回
"""
def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取数据库会话
db = get_args_db(args, kwargs)
if not db:
# 如果没有获取到数据库会话,创建一个
db = ScopedSession()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db)
try:
# 执行函数
result = func(*args, **kwargs)
except Exception as err:
raise err
finally:
# 关闭数据库会话
if _close_db:
db.close()
return result
return wrapper
class DbOper: class DbOper:
""" """
数据库操作基类 数据库操作基类
@ -46,8 +142,4 @@ class DbOper:
_db: Session = None _db: Session = None
def __init__(self, db: Session = None): def __init__(self, db: Session = None):
if db: self._db = db
self._db = db
else:
with DBLock:
self._db = ScopedSession()

View File

@ -2,31 +2,7 @@ from typing import Any, Self, List
from sqlalchemy.orm import as_declarative, declared_attr, Session from sqlalchemy.orm import as_declarative, declared_attr, Session
from app.db import DBLock from app.db import db_update, db_query
def db_persist(func):
"""
数据库操作装饰器获取第一个输入参数db执行数据库操作后提交
"""
def wrapper(*args, **kwargs):
with DBLock:
db: Session = kwargs.get("db")
if not db:
for arg in args:
if isinstance(arg, Session):
db = arg
break
try:
result = func(*args, **kwargs)
db.commit()
except Exception as err:
db.rollback()
raise err
return result
return wrapper
@as_declarative() @as_declarative()
@ -34,34 +10,38 @@ class Base:
id: Any id: Any
__name__: str __name__: str
@db_persist @db_update
def create(self, db: Session) -> Self: def create(self, db: Session) -> Self:
db.add(self) db.add(self)
db.refresh(self)
return self return self
@classmethod @classmethod
@db_query
def get(cls, db: Session, rid: int) -> Self: def get(cls, db: Session, rid: int) -> Self:
return db.query(cls).filter(cls.id == rid).first() return db.query(cls).filter(cls.id == rid).first()
@db_persist @db_update
def update(self, db: Session, payload: dict): def update(self, db: Session, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None} payload = {k: v for k, v in payload.items() if v is not None}
for key, value in payload.items(): for key, value in payload.items():
setattr(self, key, value) setattr(self, key, value)
@classmethod @classmethod
@db_persist @db_update
def delete(cls, db: Session, rid): def delete(cls, db: Session, rid):
db.query(cls).filter(cls.id == rid).delete() db.query(cls).filter(cls.id == rid).delete()
@classmethod @classmethod
@db_persist @db_update
def truncate(cls, db: Session): def truncate(cls, db: Session):
db.query(cls).delete() db.query(cls).delete()
@classmethod @classmethod
@db_query
def list(cls, db: Session) -> List[Self]: def list(cls, db: Session) -> List[Self]:
return db.query(cls).all() result = db.query(cls).all()
return list(result)
def to_dict(self): def to_dict(self):
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} return {c.name: getattr(self, c.name, None) for c in self.__table__.columns}

View File

@ -1,7 +1,8 @@
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base, db_persist from app.db import db_query
from app.db.models import Base, db_update
class DownloadHistory(Base): class DownloadHistory(Base):
@ -45,69 +46,80 @@ class DownloadHistory(Base):
note = Column(String) note = Column(String)
@staticmethod @staticmethod
@db_query
def get_by_hash(db: Session, download_hash: str): def get_by_hash(db: Session, download_hash: str):
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).first() return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).first()
@staticmethod @staticmethod
@db_query
def list_by_page(db: Session, page: int = 1, count: int = 30): def list_by_page(db: Session, page: int = 1, count: int = 30):
return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all() result = db.query(DownloadHistory).offset((page - 1) * count).limit(count).all()
return list(result)
@staticmethod @staticmethod
@db_query
def get_by_path(db: Session, path: str): def get_by_path(db: Session, path: str):
return db.query(DownloadHistory).filter(DownloadHistory.path == path).first() return db.query(DownloadHistory).filter(DownloadHistory.path == path).first()
@staticmethod @staticmethod
@db_query
def get_last_by(db: Session, mtype: str = None, title: str = None, year: int = None, season: str = None, def get_last_by(db: Session, mtype: str = None, title: str = None, year: int = None, season: str = None,
episode: str = None, tmdbid: int = None): episode: str = None, tmdbid: int = None):
""" """
据tmdbidseasonseason_episode查询转移记录 据tmdbidseasonseason_episode查询转移记录
""" """
result = None
if tmdbid and not season and not episode: if tmdbid and not season and not episode:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).order_by( result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
if tmdbid and season and not episode: if tmdbid and season and not episode:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.seasons == season).order_by( DownloadHistory.seasons == season).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
if tmdbid and season and episode: if tmdbid and season and episode:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.seasons == season, DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by( DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
# 电视剧所有季集|电影 # 电视剧所有季集|电影
if not season and not episode: if not season and not episode:
return db.query(DownloadHistory).filter(DownloadHistory.type == mtype, result = db.query(DownloadHistory).filter(DownloadHistory.type == mtype,
DownloadHistory.title == title, DownloadHistory.title == title,
DownloadHistory.year == year).order_by( DownloadHistory.year == year).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
# 电视剧某季 # 电视剧某季
if season and not episode: if season and not episode:
return db.query(DownloadHistory).filter(DownloadHistory.type == mtype, result = db.query(DownloadHistory).filter(DownloadHistory.type == mtype,
DownloadHistory.title == title, DownloadHistory.title == title,
DownloadHistory.year == year, DownloadHistory.year == year,
DownloadHistory.seasons == season).order_by( DownloadHistory.seasons == season).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
# 电视剧某季某集 # 电视剧某季某集
if season and episode: if season and episode:
return db.query(DownloadHistory).filter(DownloadHistory.type == mtype, result = db.query(DownloadHistory).filter(DownloadHistory.type == mtype,
DownloadHistory.title == title, DownloadHistory.title == title,
DownloadHistory.year == year, DownloadHistory.year == year,
DownloadHistory.seasons == season, DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by( DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
if result:
return list(result)
@staticmethod @staticmethod
@db_query
def list_by_user_date(db: Session, date: str, userid: str = None): def list_by_user_date(db: Session, date: str, userid: str = None):
""" """
查询某用户某时间之后的下载历史 查询某用户某时间之后的下载历史
""" """
if userid: if userid:
return db.query(DownloadHistory).filter(DownloadHistory.date < date, result = db.query(DownloadHistory).filter(DownloadHistory.date < date,
DownloadHistory.userid == userid).order_by( DownloadHistory.userid == userid).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
else: else:
return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by( result = db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
return list(result)
class DownloadFiles(Base): class DownloadFiles(Base):
@ -131,24 +143,30 @@ class DownloadFiles(Base):
state = Column(Integer, nullable=False, default=1) state = Column(Integer, nullable=False, default=1)
@staticmethod @staticmethod
@db_query
def get_by_hash(db: Session, download_hash: str, state: int = None): def get_by_hash(db: Session, download_hash: str, state: int = None):
if state: if state:
return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash, result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash,
DownloadFiles.state == state).all() DownloadFiles.state == state).all()
else: else:
return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all() result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all()
return list(result)
@staticmethod @staticmethod
@db_query
def get_by_fullpath(db: Session, fullpath: str): def get_by_fullpath(db: Session, fullpath: str):
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by( return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by(
DownloadFiles.id.desc()).first() DownloadFiles.id.desc()).first()
@staticmethod @staticmethod
@db_query
def get_by_savepath(db: Session, savepath: str): def get_by_savepath(db: Session, savepath: str):
return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all() result = db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all()
return list(result)
@staticmethod @staticmethod
@db_persist @db_update
def delete_by_fullpath(db: Session, fullpath: str): def delete_by_fullpath(db: Session, fullpath: str):
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath, db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath,
DownloadFiles.state == 1).update( DownloadFiles.state == 1).update(

View File

@ -3,7 +3,8 @@ from datetime import datetime
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base, db_persist from app.db import db_query
from app.db.models import Base, db_update
class MediaServerItem(Base): class MediaServerItem(Base):
@ -41,20 +42,23 @@ class MediaServerItem(Base):
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@staticmethod @staticmethod
@db_query
def get_by_itemid(db: Session, item_id: str): def get_by_itemid(db: Session, item_id: str):
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first() return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first()
@staticmethod @staticmethod
@db_persist @db_update
def empty(db: Session, server: str): def empty(db: Session, server: str):
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() db.query(MediaServerItem).filter(MediaServerItem.server == server).delete()
@staticmethod @staticmethod
@db_query
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str):
return db.query(MediaServerItem).filter(MediaServerItem.tmdbid == tmdbid, return db.query(MediaServerItem).filter(MediaServerItem.tmdbid == tmdbid,
MediaServerItem.item_type == mtype).first() MediaServerItem.item_type == mtype).first()
@staticmethod @staticmethod
@db_query
def exists_by_title(db: Session, title: str, mtype: str, year: str): def exists_by_title(db: Session, title: str, mtype: str, year: str):
return db.query(MediaServerItem).filter(MediaServerItem.title == title, return db.query(MediaServerItem).filter(MediaServerItem.title == title,
MediaServerItem.item_type == mtype, MediaServerItem.item_type == mtype,

View File

@ -1,7 +1,8 @@
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base, db_persist from app.db import db_query
from app.db.models import Base, db_update
class PluginData(Base): class PluginData(Base):
@ -14,18 +15,23 @@ class PluginData(Base):
value = Column(String) value = Column(String)
@staticmethod @staticmethod
@db_query
def get_plugin_data(db: Session, plugin_id: str): def get_plugin_data(db: Session, plugin_id: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all()
return list(result)
@staticmethod @staticmethod
@db_query
def get_plugin_data_by_key(db: Session, plugin_id: str, key: str): def get_plugin_data_by_key(db: Session, plugin_id: str, key: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first() return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first()
@staticmethod @staticmethod
@db_persist @db_update
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): def del_plugin_data_by_key(db: Session, plugin_id: str, key: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete()
@staticmethod @staticmethod
@db_query
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): def get_plugin_data_by_plugin_id(db: Session, plugin_id: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all()
return list(result)

View File

@ -3,7 +3,8 @@ from datetime import datetime
from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy import Boolean, Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base, db_persist from app.db import db_query
from app.db.models import Base, db_update
class Site(Base): class Site(Base):
@ -47,18 +48,23 @@ class Site(Base):
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@staticmethod @staticmethod
@db_query
def get_by_domain(db: Session, domain: str): def get_by_domain(db: Session, domain: str):
return db.query(Site).filter(Site.domain == domain).first() return db.query(Site).filter(Site.domain == domain).first()
@staticmethod @staticmethod
@db_query
def get_actives(db: Session): def get_actives(db: Session):
return db.query(Site).filter(Site.is_active == 1).all() result = db.query(Site).filter(Site.is_active == 1).all()
return list(result)
@staticmethod @staticmethod
@db_query
def list_order_by_pri(db: Session): def list_order_by_pri(db: Session):
return db.query(Site).order_by(Site.pri).all() result = db.query(Site).order_by(Site.pri).all()
return list(result)
@staticmethod @staticmethod
@db_persist @db_update
def reset(db: Session): def reset(db: Session):
db.query(Site).delete() db.query(Site).delete()

View File

@ -1,6 +1,7 @@
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query
from app.db.models import Base from app.db.models import Base
@ -19,5 +20,6 @@ class SiteIcon(Base):
base64 = Column(String) base64 = Column(String)
@staticmethod @staticmethod
@db_query
def get_by_domain(db: Session, domain: str): def get_by_domain(db: Session, domain: str):
return db.query(SiteIcon).filter(SiteIcon.domain == domain).first() return db.query(SiteIcon).filter(SiteIcon.domain == domain).first()

View File

@ -1,6 +1,7 @@
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_update, db_query
from app.db.models import Base from app.db.models import Base
@ -67,6 +68,7 @@ class Subscribe(Base):
current_priority = Column(Integer) current_priority = Column(Integer)
@staticmethod @staticmethod
@db_query
def exists(db: Session, tmdbid: int, season: int = None): def exists(db: Session, tmdbid: int, season: int = None):
if season: if season:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid,
@ -74,30 +76,39 @@ class Subscribe(Base):
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first() return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first()
@staticmethod @staticmethod
@db_query
def get_by_state(db: Session, state: str): def get_by_state(db: Session, state: str):
return db.query(Subscribe).filter(Subscribe.state == state).all() result = db.query(Subscribe).filter(Subscribe.state == state).all()
return list(result)
@staticmethod @staticmethod
@db_query
def get_by_tmdbid(db: Session, tmdbid: int, season: int = None): def get_by_tmdbid(db: Session, tmdbid: int, season: int = None):
if season: if season:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid,
Subscribe.season == season).all() Subscribe.season == season).all()
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all() else:
result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all()
return list(result)
@staticmethod @staticmethod
@db_query
def get_by_title(db: Session, title: str): def get_by_title(db: Session, title: str):
return db.query(Subscribe).filter(Subscribe.name == title).first() return db.query(Subscribe).filter(Subscribe.name == title).first()
@staticmethod @staticmethod
@db_query
def get_by_doubanid(db: Session, doubanid: str): def get_by_doubanid(db: Session, doubanid: str):
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first() return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first()
@db_update
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int): def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
subscrbies = self.get_by_tmdbid(db, tmdbid, season) subscrbies = self.get_by_tmdbid(db, tmdbid, season)
for subscrbie in subscrbies: for subscrbie in subscrbies:
subscrbie.delete(db, subscrbie.id) subscrbie.delete(db, subscrbie.id)
return True return True
@db_update
def delete_by_doubanid(self, db: Session, doubanid: str): def delete_by_doubanid(self, db: Session, doubanid: str):
subscribe = self.get_by_doubanid(db, doubanid) subscribe = self.get_by_doubanid(db, doubanid)
if subscribe: if subscribe:

View File

@ -1,6 +1,7 @@
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_update, db_query
from app.db.models import Base from app.db.models import Base
@ -15,9 +16,11 @@ class SystemConfig(Base):
value = Column(String, nullable=True) value = Column(String, nullable=True)
@staticmethod @staticmethod
@db_query
def get_by_key(db: Session, key: str): def get_by_key(db: Session, key: str):
return db.query(SystemConfig).filter(SystemConfig.key == key).first() return db.query(SystemConfig).filter(SystemConfig.key == key).first()
@db_update
def delete_by_key(self, db: Session, key: str): def delete_by_key(self, db: Session, key: str):
systemconfig = self.get_by_key(db, key) systemconfig = self.get_by_key(db, key)
if systemconfig: if systemconfig:

View File

@ -3,7 +3,8 @@ import time
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func from sqlalchemy import Column, Integer, String, Sequence, Boolean, func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base, db_persist from app.db import db_query
from app.db.models import Base, db_update
class TransferHistory(Base): class TransferHistory(Base):
@ -47,29 +48,38 @@ class TransferHistory(Base):
files = Column(String) files = Column(String)
@staticmethod @staticmethod
@db_query
def list_by_title(db: Session, title: str, page: int = 1, count: int = 30): def list_by_title(db: Session, title: str, page: int = 1, count: int = 30):
return db.query(TransferHistory).filter(TransferHistory.title.like(f'%{title}%')).order_by( result = db.query(TransferHistory).filter(TransferHistory.title.like(f'%{title}%')).order_by(
TransferHistory.date.desc()).offset((page - 1) * count).limit( TransferHistory.date.desc()).offset((page - 1) * count).limit(
count).all() count).all()
return list(result)
@staticmethod @staticmethod
@db_query
def list_by_page(db: Session, page: int = 1, count: int = 30): def list_by_page(db: Session, page: int = 1, count: int = 30):
return db.query(TransferHistory).order_by(TransferHistory.date.desc()).offset((page - 1) * count).limit( result = db.query(TransferHistory).order_by(TransferHistory.date.desc()).offset((page - 1) * count).limit(
count).all() count).all()
return list(result)
@staticmethod @staticmethod
@db_query
def get_by_hash(db: Session, download_hash: str): def get_by_hash(db: Session, download_hash: str):
return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).first() return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).first()
@staticmethod @staticmethod
@db_query
def get_by_src(db: Session, src: str): def get_by_src(db: Session, src: str):
return db.query(TransferHistory).filter(TransferHistory.src == src).first() return db.query(TransferHistory).filter(TransferHistory.src == src).first()
@staticmethod @staticmethod
@db_query
def list_by_hash(db: Session, download_hash: str): def list_by_hash(db: Session, download_hash: str):
return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all() result = db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all()
return list(result)
@staticmethod @staticmethod
@db_query
def statistic(db: Session, days: int = 7): def statistic(db: Session, days: int = 7):
""" """
统计最近days天的下载历史数量按日期分组返回每日数量 统计最近days天的下载历史数量按日期分组返回每日数量
@ -78,74 +88,82 @@ class TransferHistory(Base):
TransferHistory.id.label('id')).filter( TransferHistory.id.label('id')).filter(
TransferHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S", TransferHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery() time.localtime(time.time() - 86400 * days))).subquery()
return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all() result = db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all()
return list(result)
@staticmethod @staticmethod
@db_query
def count(db: Session): def count(db: Session):
return db.query(func.count(TransferHistory.id)).first()[0] return db.query(func.count(TransferHistory.id)).first()[0]
@staticmethod @staticmethod
@db_query
def count_by_title(db: Session, title: str): def count_by_title(db: Session, title: str):
return db.query(func.count(TransferHistory.id)).filter(TransferHistory.title.like(f'%{title}%')).first()[0] return db.query(func.count(TransferHistory.id)).filter(TransferHistory.title.like(f'%{title}%')).first()[0]
@staticmethod @staticmethod
@db_query
def list_by(db: Session, mtype: str = None, title: str = None, year: str = None, season: str = None, def list_by(db: Session, mtype: str = None, title: str = None, year: str = None, season: str = None,
episode: str = None, tmdbid: int = None, dest: str = None): episode: str = None, tmdbid: int = None, dest: str = None):
""" """
据tmdbidseasonseason_episode查询转移记录 据tmdbidseasonseason_episode查询转移记录
tmdbid + mtype title + year 必输 tmdbid + mtype title + year 必输
""" """
result = None
# TMDBID + 类型 # TMDBID + 类型
if tmdbid and mtype: if tmdbid and mtype:
# 电视剧某季某集 # 电视剧某季某集
if season and episode: if season and episode:
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype, TransferHistory.type == mtype,
TransferHistory.seasons == season, TransferHistory.seasons == season,
TransferHistory.episodes == episode, TransferHistory.episodes == episode,
TransferHistory.dest == dest).all() TransferHistory.dest == dest).all()
# 电视剧某季 # 电视剧某季
elif season: elif season:
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype, TransferHistory.type == mtype,
TransferHistory.seasons == season).all() TransferHistory.seasons == season).all()
else: else:
if dest: if dest:
# 电影 # 电影
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype, TransferHistory.type == mtype,
TransferHistory.dest == dest).all() TransferHistory.dest == dest).all()
else: else:
# 电视剧所有季集 # 电视剧所有季集
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype).all() TransferHistory.type == mtype).all()
# 标题 + 年份 # 标题 + 年份
elif title and year: elif title and year:
# 电视剧某季某集 # 电视剧某季某集
if season and episode: if season and episode:
return db.query(TransferHistory).filter(TransferHistory.title == title, result = db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year, TransferHistory.year == year,
TransferHistory.seasons == season, TransferHistory.seasons == season,
TransferHistory.episodes == episode, TransferHistory.episodes == episode,
TransferHistory.dest == dest).all() TransferHistory.dest == dest).all()
# 电视剧某季 # 电视剧某季
elif season: elif season:
return db.query(TransferHistory).filter(TransferHistory.title == title, result = db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year, TransferHistory.year == year,
TransferHistory.seasons == season).all() TransferHistory.seasons == season).all()
else: else:
if dest: if dest:
# 电影 # 电影
return db.query(TransferHistory).filter(TransferHistory.title == title, result = db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year, TransferHistory.year == year,
TransferHistory.dest == dest).all() TransferHistory.dest == dest).all()
else: else:
# 电视剧所有季集 # 电视剧所有季集
return db.query(TransferHistory).filter(TransferHistory.title == title, result = db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year).all() TransferHistory.year == year).all()
if result:
return list(result)
return [] return []
@staticmethod @staticmethod
@db_query
def get_by_type_tmdbid(db: Session, mtype: str = None, tmdbid: int = None): def get_by_type_tmdbid(db: Session, mtype: str = None, tmdbid: int = None):
""" """
据tmdbidtype查询转移记录 据tmdbidtype查询转移记录
@ -154,7 +172,7 @@ class TransferHistory(Base):
TransferHistory.type == mtype).first() TransferHistory.type == mtype).first()
@staticmethod @staticmethod
@db_persist @db_update
def update_download_hash(db: Session, historyid: int = None, download_hash: str = None): def update_download_hash(db: Session, historyid: int = None, download_hash: str = None):
db.query(TransferHistory).filter(TransferHistory.id == historyid).update( db.query(TransferHistory).filter(TransferHistory.id == historyid).update(
{ {

View File

@ -2,6 +2,7 @@ from sqlalchemy import Boolean, Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.security import verify_password from app.core.security import verify_password
from app.db import db_update, db_query
from app.db.models import Base from app.db.models import Base
@ -25,6 +26,7 @@ class User(Base):
avatar = Column(String) avatar = Column(String)
@staticmethod @staticmethod
@db_query
def authenticate(db: Session, name: str, password: str): def authenticate(db: Session, name: str, password: str):
user = db.query(User).filter(User.name == name).first() user = db.query(User).filter(User.name == name).first()
if not user: if not user:
@ -34,9 +36,11 @@ class User(Base):
return user return user
@staticmethod @staticmethod
@db_query
def get_by_name(db: Session, name: str): def get_by_name(db: Session, name: str):
return db.query(User).filter(User.name == name).first() return db.query(User).filter(User.name == name).first()
@db_update
def delete_by_name(self, db: Session, name: str): def delete_by_name(self, db: Session, name: str):
user = self.get_by_name(db, name) user = self.get_by_name(db, name)
if user: if user: