fix 数据库会话管理

This commit is contained in:
jxxghp 2023-08-16 10:22:45 +08:00
parent b086bbf015
commit da93328d50
29 changed files with 255 additions and 172 deletions

View File

@ -19,11 +19,12 @@ router = APIRouter()
@router.get("/statistic", summary="媒体数量统计", response_model=schemas.Statistic) @router.get("/statistic", summary="媒体数量统计", response_model=schemas.Statistic)
def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any: def statistic(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询媒体数量统计信息 查询媒体数量统计信息
""" """
media_statistic = DashboardChain().media_statistic() media_statistic = DashboardChain(db).media_statistic()
if media_statistic: if media_statistic:
return schemas.Statistic( return schemas.Statistic(
movie_count=media_statistic.movie_count, movie_count=media_statistic.movie_count,
@ -61,11 +62,12 @@ def processes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/downloader", summary="下载器信息", response_model=schemas.DownloaderInfo) @router.get("/downloader", summary="下载器信息", response_model=schemas.DownloaderInfo)
def downloader(_: schemas.TokenPayload = Depends(verify_token)) -> Any: def downloader(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询下载器信息 查询下载器信息
""" """
transfer_info = DashboardChain().downloader_info() transfer_info = DashboardChain(db).downloader_info()
free_space = SystemUtils.free_space(Path(settings.DOWNLOAD_PATH)) free_space = SystemUtils.free_space(Path(settings.DOWNLOAD_PATH))
return schemas.DownloaderInfo( return schemas.DownloaderInfo(
download_speed=transfer_info.download_speed, download_speed=transfer_info.download_speed,

View File

@ -1,12 +1,14 @@
from typing import List, Any from typing import List, Any
from fastapi import APIRouter, Depends, Response from fastapi import APIRouter, Depends, Response
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.douban import DoubanChain from app.chain.douban import DoubanChain
from app.core.config import settings from app.core.config import settings
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db
from app.schemas import MediaType from app.schemas import MediaType
from app.utils.http import RequestUtils from app.utils.http import RequestUtils
@ -30,12 +32,13 @@ def douban_img(imgurl: str) -> Any:
@router.get("/recognize/{doubanid}", summary="豆瓣ID识别", response_model=schemas.Context) @router.get("/recognize/{doubanid}", summary="豆瓣ID识别", response_model=schemas.Context)
def recognize_doubanid(doubanid: str, def recognize_doubanid(doubanid: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据豆瓣ID识别媒体信息 根据豆瓣ID识别媒体信息
""" """
# 识别媒体信息 # 识别媒体信息
context = DoubanChain().recognize_by_doubanid(doubanid=doubanid) context = DoubanChain(db).recognize_by_doubanid(doubanid=doubanid)
if context: if context:
return context.to_dict() return context.to_dict()
else: else:
@ -47,11 +50,12 @@ def douban_movies(sort: str = "R",
tags: str = "", tags: str = "",
page: int = 1, page: int = 1,
count: int = 30, count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览豆瓣电影信息 浏览豆瓣电影信息
""" """
movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE, movies = DoubanChain(db).douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count) sort=sort, tags=tags, page=page, count=count)
if not movies: if not movies:
return [] return []
@ -67,11 +71,12 @@ def douban_tvs(sort: str = "R",
tags: str = "", tags: str = "",
page: int = 1, page: int = 1,
count: int = 30, count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览豆瓣剧集信息 浏览豆瓣剧集信息
""" """
tvs = DoubanChain().douban_discover(mtype=MediaType.TV, tvs = DoubanChain(db).douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count) sort=sort, tags=tags, page=page, count=count)
if not tvs: if not tvs:
return [] return []
@ -86,42 +91,47 @@ def douban_tvs(sort: str = "R",
@router.get("/movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo]) @router.get("/movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
def movie_top250(page: int = 1, def movie_top250(page: int = 1,
count: int = 30, count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览豆瓣剧集信息 浏览豆瓣剧集信息
""" """
movies = DoubanChain().movie_top250(page=page, count=count) movies = DoubanChain(db).movie_top250(page=page, count=count)
return [MediaInfo(douban_info=movie).to_dict() for movie in movies] return [MediaInfo(douban_info=movie).to_dict() for movie in movies]
@router.get("/tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo]) @router.get("/tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
def tv_weekly_chinese(page: int = 1, def tv_weekly_chinese(page: int = 1,
count: int = 30, count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
中国每周剧集口碑榜 中国每周剧集口碑榜
""" """
tvs = DoubanChain().tv_weekly_chinese(page=page, count=count) tvs = DoubanChain(db).tv_weekly_chinese(page=page, count=count)
return [MediaInfo(douban_info=tv).to_dict() for tv in tvs] return [MediaInfo(douban_info=tv).to_dict() for tv in tvs]
@router.get("/tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo]) @router.get("/tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
def tv_weekly_global(page: int = 1, def tv_weekly_global(page: int = 1,
count: int = 30, count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
全球每周剧集口碑榜 全球每周剧集口碑榜
""" """
tvs = DoubanChain().tv_weekly_global(page=page, count=count) tvs = DoubanChain(db).tv_weekly_global(page=page, count=count)
return [MediaInfo(douban_info=tv).to_dict() for tv in tvs] return [MediaInfo(douban_info=tv).to_dict() for tv in tvs]
@router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo) @router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo)
def douban_info(doubanid: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any: def douban_info(doubanid: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据豆瓣ID查询豆瓣媒体信息 根据豆瓣ID查询豆瓣媒体信息
""" """
doubaninfo = DoubanChain().douban_info(doubanid=doubanid) doubaninfo = DoubanChain(db).douban_info(doubanid=doubanid)
if doubaninfo: if doubaninfo:
return MediaInfo(douban_info=doubaninfo).to_dict() return MediaInfo(douban_info=doubaninfo).to_dict()
else: else:

View File

@ -1,6 +1,7 @@
from typing import Any, List from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.douban import DoubanChain from app.chain.douban import DoubanChain
@ -9,6 +10,7 @@ from app.chain.media import MediaChain
from app.core.context import MediaInfo, Context, TorrentInfo from app.core.context import MediaInfo, Context, TorrentInfo
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db
from app.db.models.user import User from app.db.models.user import User
from app.db.userauth import get_current_active_superuser from app.db.userauth import get_current_active_superuser
from app.schemas import NotExistMediaInfo, MediaType from app.schemas import NotExistMediaInfo, MediaType
@ -18,18 +20,20 @@ router = APIRouter()
@router.get("/", summary="正在下载", response_model=List[schemas.DownloadingTorrent]) @router.get("/", summary="正在下载", response_model=List[schemas.DownloadingTorrent])
def read_downloading( def read_downloading(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询正在下载的任务 查询正在下载的任务
""" """
return DownloadChain().downloading() return DownloadChain(db).downloading()
@router.post("/", summary="添加下载", response_model=schemas.Response) @router.post("/", summary="添加下载", response_model=schemas.Response)
def add_downloading( def add_downloading(
media_in: schemas.MediaInfo, media_in: schemas.MediaInfo,
torrent_in: schemas.TorrentInfo, torrent_in: schemas.TorrentInfo,
current_user: User = Depends(get_current_active_superuser)) -> Any: db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
添加下载任务 添加下载任务
""" """
@ -47,7 +51,7 @@ def add_downloading(
media_info=mediainfo, media_info=mediainfo,
torrent_info=torrentinfo torrent_info=torrentinfo
) )
did = DownloadChain().download_single(context=context) did = DownloadChain(db).download_single(context=context)
return schemas.Response(success=True if did else False, data={ return schemas.Response(success=True if did else False, data={
"download_id": did "download_id": did
}) })
@ -55,6 +59,7 @@ def add_downloading(
@router.post("/notexists", summary="查询缺失媒体信息", response_model=List[NotExistMediaInfo]) @router.post("/notexists", summary="查询缺失媒体信息", response_model=List[NotExistMediaInfo])
def exists(media_in: schemas.MediaInfo, def exists(media_in: schemas.MediaInfo,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询缺失媒体信息 查询缺失媒体信息
@ -65,19 +70,19 @@ def exists(media_in: schemas.MediaInfo,
if media_in.tmdb_id: if media_in.tmdb_id:
mediainfo.from_dict(media_in.dict()) mediainfo.from_dict(media_in.dict())
elif media_in.douban_id: elif media_in.douban_id:
context = DoubanChain().recognize_by_doubanid(doubanid=media_in.douban_id) context = DoubanChain(db).recognize_by_doubanid(doubanid=media_in.douban_id)
if context: if context:
mediainfo = context.media_info mediainfo = context.media_info
meta = context.meta_info meta = context.meta_info
else: else:
context = MediaChain().recognize_by_title(title=f"{media_in.title} {media_in.year}") context = MediaChain(db).recognize_by_title(title=f"{media_in.title} {media_in.year}")
if context: if context:
mediainfo = context.media_info mediainfo = context.media_info
meta = context.meta_info meta = context.meta_info
# 查询缺失信息 # 查询缺失信息
if not mediainfo or not mediainfo.tmdb_id: if not mediainfo or not mediainfo.tmdb_id:
raise HTTPException(status_code=404, detail="媒体信息不存在") raise HTTPException(status_code=404, detail="媒体信息不存在")
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo) exist_flag, no_exists = DownloadChain(db).get_no_exists_info(meta=meta, mediainfo=mediainfo)
if mediainfo.type == MediaType.MOVIE: if mediainfo.type == MediaType.MOVIE:
# 电影已存在时返回空列表,存在时返回空对像列表 # 电影已存在时返回空列表,存在时返回空对像列表
return [] if exist_flag else [NotExistMediaInfo()] return [] if exist_flag else [NotExistMediaInfo()]
@ -90,31 +95,34 @@ def exists(media_in: schemas.MediaInfo,
@router.put("/{hashString}/start", summary="开始任务", response_model=schemas.Response) @router.put("/{hashString}/start", summary="开始任务", response_model=schemas.Response)
def start_downloading( def start_downloading(
hashString: str, hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
开如下载任务 开如下载任务
""" """
ret = DownloadChain().set_downloading(hashString, "start") ret = DownloadChain(db).set_downloading(hashString, "start")
return schemas.Response(success=True if ret else False) return schemas.Response(success=True if ret else False)
@router.put("/{hashString}/stop", summary="暂停任务", response_model=schemas.Response) @router.put("/{hashString}/stop", summary="暂停任务", response_model=schemas.Response)
def stop_downloading( def stop_downloading(
hashString: str, hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
控制下载任务 控制下载任务
""" """
ret = DownloadChain().set_downloading(hashString, "stop") ret = DownloadChain(db).set_downloading(hashString, "stop")
return schemas.Response(success=True if ret else False) return schemas.Response(success=True if ret else False)
@router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response) @router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response)
def remove_downloading( def remove_downloading(
hashString: str, hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
控制下载任务 控制下载任务
""" """
ret = DownloadChain().remove_downloading(hashString) ret = DownloadChain(db).remove_downloading(hashString)
return schemas.Response(success=True if ret else False) return schemas.Response(success=True if ret else False)

View File

@ -74,7 +74,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
if not history: if not history:
return schemas.Response(success=False, msg="记录不存在") return schemas.Response(success=False, msg="记录不存在")
# 册除文件 # 册除文件
TransferChain().delete_files(Path(history.dest)) TransferChain(db).delete_files(Path(history.dest))
# 删除记录 # 删除记录
TransferHistory.delete(db, history_in.id) TransferHistory.delete(db, history_in.id)
return schemas.Response(success=True) return schemas.Response(success=True)
@ -84,12 +84,13 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
def redo_transfer_history(history_in: schemas.TransferHistory, def redo_transfer_history(history_in: schemas.TransferHistory,
mtype: str, mtype: str,
new_tmdbid: int, new_tmdbid: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
历史记录重新转移 历史记录重新转移
""" """
hash_str = history_in.download_hash hash_str = history_in.download_hash
result = TransferChain().process(f"{hash_str} {new_tmdbid}|{mtype}") result = TransferChain(db).process(f"{hash_str} {new_tmdbid}|{mtype}")
if result: if result:
return schemas.Response(success=True) return schemas.Response(success=True)
else: else:

View File

@ -36,7 +36,7 @@ async def login_access_token(
if not user: if not user:
# 请求协助认证 # 请求协助认证
logger.warn("登录用户本地不匹配,尝试辅助认证 ...") logger.warn("登录用户本地不匹配,尝试辅助认证 ...")
token = UserChain().user_authenticate(form_data.username, form_data.password) token = UserChain(db).user_authenticate(form_data.username, form_data.password)
if not token: if not token:
raise HTTPException(status_code=401, detail="用户名或密码不正确") raise HTTPException(status_code=401, detail="用户名或密码不正确")
else: else:
@ -83,11 +83,11 @@ def bing_wallpaper() -> Any:
@router.get("/tmdb", summary="TMDB电影海报", response_model=schemas.Response) @router.get("/tmdb", summary="TMDB电影海报", response_model=schemas.Response)
def tmdb_wallpaper() -> Any: def tmdb_wallpaper(db: Session = Depends(get_db)) -> Any:
""" """
获取TMDB电影海报 获取TMDB电影海报
""" """
infos = TmdbChain().tmdb_trending() infos = TmdbChain(db).tmdb_trending()
if infos: if infos:
# 随机一个电影 # 随机一个电影
while True: while True:

View File

@ -20,12 +20,13 @@ router = APIRouter()
@router.get("/recognize", summary="识别媒体信息", response_model=schemas.Context) @router.get("/recognize", summary="识别媒体信息", response_model=schemas.Context)
def recognize(title: str, def recognize(title: str,
subtitle: str = None, subtitle: str = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据标题副标题识别媒体信息 根据标题副标题识别媒体信息
""" """
# 识别媒体信息 # 识别媒体信息
context = MediaChain().recognize_by_title(title=title, subtitle=subtitle) context = MediaChain(db).recognize_by_title(title=title, subtitle=subtitle)
if context: if context:
return context.to_dict() return context.to_dict()
return schemas.Context() return schemas.Context()
@ -35,11 +36,12 @@ def recognize(title: str,
def search_by_title(title: str, def search_by_title(title: str,
page: int = 1, page: int = 1,
count: int = 8, count: int = 8,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
模糊搜索媒体信息列表 模糊搜索媒体信息列表
""" """
_, medias = MediaChain().search(title=title) _, medias = MediaChain(db).search(title=title)
if medias: if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]] return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return [] return []
@ -69,20 +71,21 @@ def exists(title: str = None,
@router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo) @router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo)
def tmdb_info(mediaid: str, type_name: str, def tmdb_info(mediaid: str, type_name: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据媒体ID查询themoviedb或豆瓣媒体信息type_name: 电影/电视剧 根据媒体ID查询themoviedb或豆瓣媒体信息type_name: 电影/电视剧
""" """
mtype = MediaType(type_name) mtype = MediaType(type_name)
if mediaid.startswith("tmdb:"): if mediaid.startswith("tmdb:"):
result = TmdbChain().tmdb_info(int(mediaid[5:]), mtype) result = TmdbChain(db).tmdb_info(int(mediaid[5:]), mtype)
return MediaInfo(tmdb_info=result).to_dict() return MediaInfo(tmdb_info=result).to_dict()
elif mediaid.startswith("douban:"): elif mediaid.startswith("douban:"):
# 查询豆瓣信息 # 查询豆瓣信息
doubaninfo = DoubanChain().douban_info(doubanid=mediaid[7:]) doubaninfo = DoubanChain(db).douban_info(doubanid=mediaid[7:])
if not doubaninfo: if not doubaninfo:
return schemas.MediaInfo() return schemas.MediaInfo()
result = DoubanChain().recognize_by_doubaninfo(doubaninfo) result = DoubanChain(db).recognize_by_doubaninfo(doubaninfo)
if result: if result:
# TMDB # TMDB
return result.media_info.to_dict() return result.media_info.to_dict()

View File

@ -19,22 +19,23 @@ from app.schemas.types import SystemConfigKey, NotificationType
router = APIRouter() router = APIRouter()
def start_message_chain(body: Any, form: Any, args: Any): def start_message_chain(db: Session, body: Any, form: Any, args: Any):
""" """
启动链式任务 启动链式任务
""" """
MessageChain().process(body=body, form=form, args=args) MessageChain(db).process(body=body, form=form, args=args)
@router.post("/", summary="接收用户消息", response_model=schemas.Response) @router.post("/", summary="接收用户消息", response_model=schemas.Response)
async def user_message(background_tasks: BackgroundTasks, request: Request): async def user_message(background_tasks: BackgroundTasks, request: Request,
db: Session = Depends(get_db)):
""" """
用户消息响应 用户消息响应
""" """
body = await request.body() body = await request.body()
form = await request.form() form = await request.form()
args = request.query_params args = request.query_params
background_tasks.add_task(start_message_chain, body, form, args) background_tasks.add_task(start_message_chain, db, body, form, args)
return schemas.Response(success=True) return schemas.Response(success=True)

View File

@ -15,11 +15,11 @@ from app.schemas import MediaType
router = APIRouter() router = APIRouter()
def start_rss_refresh(rssid: int = None): def start_rss_refresh(db: Session, rssid: int = None):
""" """
启动自定义订阅刷新 启动自定义订阅刷新
""" """
RssChain().refresh(rssid=rssid, manual=True) RssChain(db).refresh(rssid=rssid, manual=True)
@router.get("/", summary="所有自定义订阅", response_model=List[schemas.Rss]) @router.get("/", summary="所有自定义订阅", response_model=List[schemas.Rss])
@ -36,6 +36,7 @@ def read_rsses(
def create_rss( def create_rss(
*, *,
rss_in: schemas.Rss, rss_in: schemas.Rss,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token) _: schemas.TokenPayload = Depends(verify_token)
) -> Any: ) -> Any:
""" """
@ -45,7 +46,7 @@ def create_rss(
mtype = MediaType(rss_in.type) mtype = MediaType(rss_in.type)
else: else:
mtype = None mtype = None
rssid, errormsg = RssChain().add( rssid, errormsg = RssChain(db).add(
mtype=mtype, mtype=mtype,
**rss_in.dict() **rss_in.dict()
) )
@ -100,11 +101,13 @@ def preview_rss(
def refresh_rss( def refresh_rss(
rssid: int, rssid: int,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据ID刷新自定义订阅 根据ID刷新自定义订阅
""" """
background_tasks.add_task(start_rss_refresh, background_tasks.add_task(start_rss_refresh,
db=db,
rssid=rssid) rssid=rssid)
return schemas.Response(success=True) return schemas.Response(success=True)

View File

@ -1,28 +1,32 @@
from typing import List, Any from typing import List, Any
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.douban import DoubanChain from app.chain.douban import DoubanChain
from app.chain.search import SearchChain from app.chain.search import SearchChain
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db
from app.schemas.types import MediaType from app.schemas.types import MediaType
router = APIRouter() router = APIRouter()
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context]) @router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any: async def search_latest(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询搜索结果 查询搜索结果
""" """
torrents = SearchChain().last_search_results() torrents = SearchChain(db).last_search_results()
return [torrent.to_dict() for torrent in torrents] return [torrent.to_dict() for torrent in torrents]
@router.get("/media/{mediaid}", summary="精确搜索资源", response_model=List[schemas.Context]) @router.get("/media/{mediaid}", summary="精确搜索资源", response_model=List[schemas.Context])
def search_by_tmdbid(mediaid: str, def search_by_tmdbid(mediaid: str,
mtype: str = None, mtype: str = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/ 根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/
@ -31,14 +35,14 @@ def search_by_tmdbid(mediaid: str,
tmdbid = int(mediaid.replace("tmdb:", "")) tmdbid = int(mediaid.replace("tmdb:", ""))
if mtype: if mtype:
mtype = MediaType(mtype) mtype = MediaType(mtype)
torrents = SearchChain().search_by_tmdbid(tmdbid=tmdbid, mtype=mtype) torrents = SearchChain(db).search_by_tmdbid(tmdbid=tmdbid, mtype=mtype)
elif mediaid.startswith("douban:"): elif mediaid.startswith("douban:"):
doubanid = mediaid.replace("douban:", "") doubanid = mediaid.replace("douban:", "")
# 识别豆瓣信息 # 识别豆瓣信息
context = DoubanChain().recognize_by_doubanid(doubanid) context = DoubanChain(db).recognize_by_doubanid(doubanid)
if not context or not context.media_info or not context.media_info.tmdb_id: if not context or not context.media_info or not context.media_info.tmdb_id:
raise HTTPException(status_code=404, detail="无法识别TMDB媒体信息") raise HTTPException(status_code=404, detail="无法识别TMDB媒体信息")
torrents = SearchChain().search_by_tmdbid(tmdbid=context.media_info.tmdb_id, torrents = SearchChain(db).search_by_tmdbid(tmdbid=context.media_info.tmdb_id,
mtype=context.media_info.type) mtype=context.media_info.type)
else: else:
return [] return []
@ -49,9 +53,10 @@ def search_by_tmdbid(mediaid: str,
async def search_by_title(keyword: str = None, async def search_by_title(keyword: str = None,
page: int = 0, page: int = 0,
site: int = None, site: int = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据名称模糊搜索站点资源支持分页关键词为空是返回首页资源 根据名称模糊搜索站点资源支持分页关键词为空是返回首页资源
""" """
torrents = SearchChain().search_by_title(title=keyword, page=page, site=site) torrents = SearchChain(db).search_by_title(title=keyword, page=page, site=site)
return [torrent.to_dict() for torrent in torrents] return [torrent.to_dict() for torrent in torrents]

View File

@ -19,11 +19,11 @@ from app.utils.string import StringUtils
router = APIRouter() router = APIRouter()
def start_cookiecloud_sync(): def start_cookiecloud_sync(db: Session):
""" """
后台启动CookieCloud站点同步 后台启动CookieCloud站点同步
""" """
CookieCloudChain().process(manual=True) CookieCloudChain(db).process(manual=True)
@router.get("/", summary="所有站点", response_model=List[schemas.Site]) @router.get("/", summary="所有站点", response_model=List[schemas.Site])
@ -67,11 +67,12 @@ def delete_site(
@router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response) @router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response)
def cookie_cloud_sync(background_tasks: BackgroundTasks, def cookie_cloud_sync(background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
运行CookieCloud同步站点信息 运行CookieCloud同步站点信息
""" """
background_tasks.add_task(start_cookiecloud_sync) background_tasks.add_task(start_cookiecloud_sync, db)
return schemas.Response(success=True, message="CookieCloud同步任务已启动") return schemas.Response(success=True, message="CookieCloud同步任务已启动")
@ -83,7 +84,7 @@ def cookie_cloud_sync(db: Session = Depends(get_db),
""" """
Site.reset(db) Site.reset(db)
SystemConfigOper(db).set(SystemConfigKey.IndexerSites, []) SystemConfigOper(db).set(SystemConfigKey.IndexerSites, [])
CookieCloudChain().process(manual=True) CookieCloudChain(db).process(manual=True)
return schemas.Response(success=True, message="站点已重置!") return schemas.Response(success=True, message="站点已重置!")
@ -105,7 +106,7 @@ def update_cookie(
detail=f"站点 {site_id} 不存在!", detail=f"站点 {site_id} 不存在!",
) )
# 更新Cookie # 更新Cookie
state, message = SiteChain().update_cookie(site_info=site_info, state, message = SiteChain(db).update_cookie(site_info=site_info,
username=username, username=username,
password=password) password=password)
return schemas.Response(success=state, message=message) return schemas.Response(success=state, message=message)
@ -124,7 +125,7 @@ def test_site(site_id: int,
status_code=404, status_code=404,
detail=f"站点 {site_id} 不存在", detail=f"站点 {site_id} 不存在",
) )
status, message = SiteChain().test(site.domain) status, message = SiteChain(db).test(site.domain)
return schemas.Response(success=status, message=message) return schemas.Response(success=status, message=message)
@ -162,7 +163,7 @@ def site_resource(site_id: int, keyword: str = None,
status_code=404, status_code=404,
detail=f"站点 {site_id} 不存在", detail=f"站点 {site_id} 不存在",
) )
torrents = SearchChain().browse(site.domain, keyword) torrents = SearchChain(db).browse(site.domain, keyword)
if not torrents: if not torrents:
return [] return []
return [torrent.to_dict() for torrent in torrents] return [torrent.to_dict() for torrent in torrents]

View File

@ -17,20 +17,20 @@ from app.schemas.types import MediaType
router = APIRouter() router = APIRouter()
def start_subscribe_add(title: str, year: str, def start_subscribe_add(db: Session, title: str, year: str,
mtype: MediaType, tmdbid: int, season: int, username: str): mtype: MediaType, tmdbid: int, season: int, username: str):
""" """
启动订阅任务 启动订阅任务
""" """
SubscribeChain().add(title=title, year=year, SubscribeChain(db).add(title=title, year=year,
mtype=mtype, tmdbid=tmdbid, season=season, username=username) mtype=mtype, tmdbid=tmdbid, season=season, username=username)
def start_subscribe_search(sid: Optional[int], state: Optional[str]): def start_subscribe_search(db: Session, sid: Optional[int], state: Optional[str]):
""" """
启动订阅搜索任务 启动订阅搜索任务
""" """
SubscribeChain().search(sid=sid, state=state, manual=True) SubscribeChain(db).search(sid=sid, state=state, manual=True)
@router.get("/", summary="所有订阅", response_model=List[schemas.Subscribe]) @router.get("/", summary="所有订阅", response_model=List[schemas.Subscribe])
@ -51,6 +51,7 @@ def read_subscribes(
def create_subscribe( def create_subscribe(
*, *,
subscribe_in: schemas.Subscribe, subscribe_in: schemas.Subscribe,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user), current_user: User = Depends(get_current_active_user),
) -> Any: ) -> Any:
""" """
@ -66,7 +67,7 @@ def create_subscribe(
title = subscribe_in.name title = subscribe_in.name
else: else:
title = None title = None
sid, message = SubscribeChain().add(mtype=mtype, sid, message = SubscribeChain(db).add(mtype=mtype,
title=title, title=title,
year=subscribe_in.year, year=subscribe_in.year,
tmdbid=subscribe_in.tmdbid, tmdbid=subscribe_in.tmdbid,
@ -171,6 +172,7 @@ def delete_subscribe(
@router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response) @router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response)
async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
authorization: str = Header(None)) -> Any: authorization: str = Header(None)) -> Any:
""" """
Jellyseerr/Overseerr订阅 Jellyseerr/Overseerr订阅
@ -198,6 +200,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
# 添加订阅 # 添加订阅
if media_type == MediaType.MOVIE: if media_type == MediaType.MOVIE:
background_tasks.add_task(start_subscribe_add, background_tasks.add_task(start_subscribe_add,
db=db,
mtype=media_type, mtype=media_type,
tmdbid=tmdbId, tmdbid=tmdbId,
title=subject, title=subject,
@ -212,6 +215,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
break break
for season in seasons: for season in seasons:
background_tasks.add_task(start_subscribe_add, background_tasks.add_task(start_subscribe_add,
db=db,
mtype=media_type, mtype=media_type,
tmdbid=tmdbId, tmdbid=tmdbId,
title=subject, title=subject,
@ -224,11 +228,12 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
@router.get("/refresh", summary="刷新订阅", response_model=schemas.Response) @router.get("/refresh", summary="刷新订阅", response_model=schemas.Response)
def refresh_subscribes( def refresh_subscribes(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
刷新所有订阅 刷新所有订阅
""" """
SubscribeChain().refresh() SubscribeChain(db).refresh()
return schemas.Response(success=True) return schemas.Response(success=True)
@ -236,20 +241,22 @@ def refresh_subscribes(
def search_subscribe( def search_subscribe(
subscribe_id: int, subscribe_id: int,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
搜索所有订阅 搜索所有订阅
""" """
background_tasks.add_task(start_subscribe_search, sid=subscribe_id, state=None) background_tasks.add_task(start_subscribe_search, db=db, sid=subscribe_id, state=None)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/search", summary="搜索所有订阅", response_model=schemas.Response) @router.get("/search", summary="搜索所有订阅", response_model=schemas.Response)
def search_subscribes( def search_subscribes(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
搜索所有订阅 搜索所有订阅
""" """
background_tasks.add_task(start_subscribe_search, sid=None, state='R') background_tasks.add_task(start_subscribe_search, db=db, sid=None, state='R')
return schemas.Response(success=True) return schemas.Response(success=True)

View File

@ -1,22 +1,25 @@
from typing import List, Any from typing import List, Any
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.tmdb import TmdbChain from app.chain.tmdb import TmdbChain
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db
from app.schemas.types import MediaType from app.schemas.types import MediaType
router = APIRouter() router = APIRouter()
@router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason]) @router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason])
def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any: def tmdb_seasons(tmdbid: int, db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询themoviedb所有季信息 根据TMDBID查询themoviedb所有季信息
""" """
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid) seasons_info = TmdbChain(db).tmdb_seasons(tmdbid=tmdbid)
if not seasons_info: if not seasons_info:
return [] return []
else: else:
@ -26,15 +29,16 @@ def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -
@router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo]) @router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_similar(tmdbid: int, def tmdb_similar(tmdbid: int,
type_name: str, type_name: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询类似电影/电视剧type_name: 电影/电视剧 根据TMDBID查询类似电影/电视剧type_name: 电影/电视剧
""" """
mediatype = MediaType(type_name) mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE: if mediatype == MediaType.MOVIE:
tmdbinfos = TmdbChain().movie_similar(tmdbid=tmdbid) tmdbinfos = TmdbChain(db).movie_similar(tmdbid=tmdbid)
elif mediatype == MediaType.TV: elif mediatype == MediaType.TV:
tmdbinfos = TmdbChain().tv_similar(tmdbid=tmdbid) tmdbinfos = TmdbChain(db).tv_similar(tmdbid=tmdbid)
else: else:
return [] return []
if not tmdbinfos: if not tmdbinfos:
@ -46,15 +50,16 @@ def tmdb_similar(tmdbid: int,
@router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo]) @router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_recommend(tmdbid: int, def tmdb_recommend(tmdbid: int,
type_name: str, type_name: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询推荐电影/电视剧type_name: 电影/电视剧 根据TMDBID查询推荐电影/电视剧type_name: 电影/电视剧
""" """
mediatype = MediaType(type_name) mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE: if mediatype == MediaType.MOVIE:
tmdbinfos = TmdbChain().movie_recommend(tmdbid=tmdbid) tmdbinfos = TmdbChain(db).movie_recommend(tmdbid=tmdbid)
elif mediatype == MediaType.TV: elif mediatype == MediaType.TV:
tmdbinfos = TmdbChain().tv_recommend(tmdbid=tmdbid) tmdbinfos = TmdbChain(db).tv_recommend(tmdbid=tmdbid)
else: else:
return [] return []
if not tmdbinfos: if not tmdbinfos:
@ -67,15 +72,16 @@ def tmdb_recommend(tmdbid: int,
def tmdb_credits(tmdbid: int, def tmdb_credits(tmdbid: int,
type_name: str, type_name: str,
page: int = 1, page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询演员阵容type_name: 电影/电视剧 根据TMDBID查询演员阵容type_name: 电影/电视剧
""" """
mediatype = MediaType(type_name) mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE: if mediatype == MediaType.MOVIE:
tmdbinfos = TmdbChain().movie_credits(tmdbid=tmdbid, page=page) tmdbinfos = TmdbChain(db).movie_credits(tmdbid=tmdbid, page=page)
elif mediatype == MediaType.TV: elif mediatype == MediaType.TV:
tmdbinfos = TmdbChain().tv_credits(tmdbid=tmdbid, page=page) tmdbinfos = TmdbChain(db).tv_credits(tmdbid=tmdbid, page=page)
else: else:
return [] return []
if not tmdbinfos: if not tmdbinfos:
@ -86,11 +92,12 @@ def tmdb_credits(tmdbid: int,
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.TmdbPerson) @router.get("/person/{person_id}", summary="人物详情", response_model=schemas.TmdbPerson)
def tmdb_person(person_id: int, def tmdb_person(person_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据人物ID查询人物详情 根据人物ID查询人物详情
""" """
tmdbinfo = TmdbChain().person_detail(person_id=person_id) tmdbinfo = TmdbChain(db).person_detail(person_id=person_id)
if not tmdbinfo: if not tmdbinfo:
return schemas.TmdbPerson() return schemas.TmdbPerson()
else: else:
@ -100,11 +107,12 @@ def tmdb_person(person_id: int,
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo]) @router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def tmdb_person_credits(person_id: int, def tmdb_person_credits(person_id: int,
page: int = 1, page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据人物ID查询人物参演作品 根据人物ID查询人物参演作品
""" """
tmdbinfo = TmdbChain().person_credits(person_id=person_id, page=page) tmdbinfo = TmdbChain(db).person_credits(person_id=person_id, page=page)
if not tmdbinfo: if not tmdbinfo:
return [] return []
else: else:
@ -116,11 +124,12 @@ def tmdb_movies(sort_by: str = "popularity.desc",
with_genres: str = "", with_genres: str = "",
with_original_language: str = "", with_original_language: str = "",
page: int = 1, page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览TMDB电影信息 浏览TMDB电影信息
""" """
movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE, movies = TmdbChain(db).tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by, sort_by=sort_by,
with_genres=with_genres, with_genres=with_genres,
with_original_language=with_original_language, with_original_language=with_original_language,
@ -135,11 +144,12 @@ def tmdb_tvs(sort_by: str = "popularity.desc",
with_genres: str = "", with_genres: str = "",
with_original_language: str = "", with_original_language: str = "",
page: int = 1, page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览TMDB剧集信息 浏览TMDB剧集信息
""" """
tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV, tvs = TmdbChain(db).tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by, sort_by=sort_by,
with_genres=with_genres, with_genres=with_genres,
with_original_language=with_original_language, with_original_language=with_original_language,
@ -151,11 +161,12 @@ def tmdb_tvs(sort_by: str = "popularity.desc",
@router.get("/trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo]) @router.get("/trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo])
def tmdb_trending(page: int = 1, def tmdb_trending(page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览TMDB剧集信息 浏览TMDB剧集信息
""" """
infos = TmdbChain().tmdb_trending(page=page) infos = TmdbChain(db).tmdb_trending(page=page)
if not infos: if not infos:
return [] return []
return [MediaInfo(tmdb_info=info).to_dict() for info in infos] return [MediaInfo(tmdb_info=info).to_dict() for info in infos]
@ -163,11 +174,12 @@ def tmdb_trending(page: int = 1,
@router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode]) @router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode])
def tmdb_season_episodes(tmdbid: int, season: int, def tmdb_season_episodes(tmdbid: int, season: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询某季的所有信信息 根据TMDBID查询某季的所有信信息
""" """
episodes_info = TmdbChain().tmdb_episodes(tmdbid=tmdbid, season=season) episodes_info = TmdbChain(db).tmdb_episodes(tmdbid=tmdbid, season=season)
if not episodes_info: if not episodes_info:
return [] return []
else: else:

View File

@ -1,24 +1,27 @@
from typing import Any from typing import Any
from fastapi import APIRouter, BackgroundTasks, Request from fastapi import APIRouter, BackgroundTasks, Request, Depends
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.webhook import WebhookChain from app.chain.webhook import WebhookChain
from app.core.config import settings from app.core.config import settings
from app.db import get_db
router = APIRouter() router = APIRouter()
def start_webhook_chain(body: Any, form: Any, args: Any): def start_webhook_chain(db: Session, body: Any, form: Any, args: Any):
""" """
启动链式任务 启动链式任务
""" """
WebhookChain().message(body=body, form=form, args=args) WebhookChain(db).message(body=body, form=form, args=args)
@router.post("/", summary="Webhook消息响应", response_model=schemas.Response) @router.post("/", summary="Webhook消息响应", response_model=schemas.Response)
async def webhook_message(background_tasks: BackgroundTasks, async def webhook_message(background_tasks: BackgroundTasks,
token: str, request: Request) -> Any: token: str, request: Request,
db: Session = Depends(get_db),) -> Any:
""" """
Webhook响应 Webhook响应
""" """
@ -27,18 +30,19 @@ async def webhook_message(background_tasks: BackgroundTasks,
body = await request.body() body = await request.body()
form = await request.form() form = await request.form()
args = request.query_params args = request.query_params
background_tasks.add_task(start_webhook_chain, body, form, args) background_tasks.add_task(start_webhook_chain, db, body, form, args)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/", summary="Webhook消息响应", response_model=schemas.Response) @router.get("/", summary="Webhook消息响应", response_model=schemas.Response)
async def webhook_message(background_tasks: BackgroundTasks, async def webhook_message(background_tasks: BackgroundTasks,
token: str, request: Request) -> Any: token: str, request: Request,
db: Session = Depends(get_db)) -> Any:
""" """
Webhook响应 Webhook响应
""" """
if token != settings.API_TOKEN: if token != settings.API_TOKEN:
return schemas.Response(success=False, message="token认证不通过") return schemas.Response(success=False, message="token认证不通过")
args = request.query_params args = request.query_params
background_tasks.add_task(start_webhook_chain, None, None, args) background_tasks.add_task(start_webhook_chain, db, None, None, args)
return schemas.Response(success=True) return schemas.Response(success=True)

View File

@ -235,11 +235,11 @@ def arr_movie_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> A
) )
tmdbid = term.replace("tmdb:", "") tmdbid = term.replace("tmdb:", "")
# 查询媒体信息 # 查询媒体信息
mediainfo = MediaChain().recognize_media(mtype=MediaType.MOVIE, tmdbid=int(tmdbid)) mediainfo = MediaChain(db).recognize_media(mtype=MediaType.MOVIE, tmdbid=int(tmdbid))
if not mediainfo: if not mediainfo:
return [RadarrMovie()] return [RadarrMovie()]
# 查询是否已存在 # 查询是否已存在
exists = MediaChain().media_exists(mediainfo=mediainfo) exists = MediaChain(db).media_exists(mediainfo=mediainfo)
if not exists: if not exists:
# 文件不存在 # 文件不存在
hasfile = False hasfile = False
@ -324,7 +324,7 @@ def arr_add_movie(apikey: str,
"id": subscribe.id "id": subscribe.id
} }
# 添加订阅 # 添加订阅
sid, message = SubscribeChain().add(title=movie.title, sid, message = SubscribeChain(db).add(title=movie.title,
year=movie.year, year=movie.year,
mtype=MediaType.MOVIE, mtype=MediaType.MOVIE,
tmdbid=movie.tmdbId, tmdbid=movie.tmdbId,
@ -515,7 +515,7 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) ->
# 获取TVDBID # 获取TVDBID
if not term.startswith("tvdb:"): if not term.startswith("tvdb:"):
mediainfo = MediaChain().recognize_media(meta=MetaInfo(term), mediainfo = MediaChain(db).recognize_media(meta=MetaInfo(term),
mtype=MediaType.TV) mtype=MediaType.TV)
if not mediainfo: if not mediainfo:
return [SonarrSeries()] return [SonarrSeries()]
@ -527,7 +527,7 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) ->
tvdbid = int(term.replace("tvdb:", "")) tvdbid = int(term.replace("tvdb:", ""))
# 查询TVDB信息 # 查询TVDB信息
tvdbinfo = MediaChain().tvdb_info(tvdbid=tvdbid) tvdbinfo = MediaChain(db).tvdb_info(tvdbid=tvdbid)
if not tvdbinfo: if not tvdbinfo:
return [SonarrSeries()] return [SonarrSeries()]
@ -539,11 +539,11 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) ->
# 根据TVDB查询媒体信息 # 根据TVDB查询媒体信息
if not mediainfo: if not mediainfo:
mediainfo = MediaChain().recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')), mediainfo = MediaChain(db).recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')),
mtype=MediaType.TV) mtype=MediaType.TV)
# 查询是否存在 # 查询是否存在
exists = MediaChain().media_exists(mediainfo) exists = MediaChain(db).media_exists(mediainfo)
if exists: if exists:
hasfile = True hasfile = True
else: else:
@ -666,7 +666,7 @@ def arr_add_series(apikey: str, tv: schemas.SonarrSeries,
for season in left_seasons: for season in left_seasons:
if not season.get("monitored"): if not season.get("monitored"):
continue continue
sid, message = SubscribeChain().add(title=tv.title, sid, message = SubscribeChain(db).add(title=tv.title,
year=tv.year, year=tv.year,
season=season.get("seasonNumber"), season=season.get("seasonNumber"),
tmdbid=tv.tmdbId, tmdbid=tv.tmdbId,

View File

@ -7,6 +7,7 @@ from typing import Optional, Any, Tuple, List, Set, Union, Dict
from qbittorrentapi import TorrentFilesList from qbittorrentapi import TorrentFilesList
from ruamel.yaml import CommentedMap from ruamel.yaml import CommentedMap
from sqlalchemy.orm import Session
from transmission_rpc import File from transmission_rpc import File
from app.core.config import settings from app.core.config import settings
@ -27,10 +28,11 @@ class ChainBase(metaclass=ABCMeta):
处理链基类 处理链基类
""" """
def __init__(self): def __init__(self, db: Session = None):
""" """
公共初始化 公共初始化
""" """
self._db = db
self.modulemanager = ModuleManager() self.modulemanager = ModuleManager()
self.eventmanager = EventManager() self.eventmanager = EventManager()

View File

@ -3,6 +3,7 @@ from typing import Tuple, Optional, Union
from urllib.parse import urljoin from urllib.parse import urljoin
from lxml import etree from lxml import etree
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.site import SiteChain from app.chain.site import SiteChain
@ -22,12 +23,12 @@ class CookieCloudChain(ChainBase):
CookieCloud处理链 CookieCloud处理链
""" """
def __init__(self): def __init__(self, db: Session = None):
super().__init__() super().__init__(db)
self.siteoper = SiteOper() self.siteoper = SiteOper(self._db)
self.siteiconoper = SiteIconOper() self.siteiconoper = SiteIconOper(self._db)
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
self.sitechain = SiteChain() self.sitechain = SiteChain(self._db)
self.message = MessageHelper() self.message = MessageHelper()
self.cookiecloud = CookieCloudHelper( self.cookiecloud = CookieCloudHelper(
server=settings.COOKIECLOUD_HOST, server=settings.COOKIECLOUD_HOST,

View File

@ -2,6 +2,8 @@ import re
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Set, Dict, Union from typing import List, Optional, Tuple, Set, Dict, Union
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.core.context import MediaInfo, TorrentInfo, Context from app.core.context import MediaInfo, TorrentInfo, Context
@ -20,11 +22,11 @@ class DownloadChain(ChainBase):
下载处理链 下载处理链
""" """
def __init__(self): def __init__(self, db: Session = None):
super().__init__() super().__init__(db)
self.torrent = TorrentHelper() self.torrent = TorrentHelper()
self.downloadhis = DownloadHistoryOper() self.downloadhis = DownloadHistoryOper(self._db)
self.mediaserver = MediaServerOper() self.mediaserver = MediaServerOper(self._db)
def post_download_message(self, meta: MetaBase, mediainfo: MediaInfo, torrent: TorrentInfo, def post_download_message(self, meta: MetaBase, mediainfo: MediaInfo, torrent: TorrentInfo,
channel: MessageChannel = None, channel: MessageChannel = None,

View File

@ -2,6 +2,8 @@ import json
import threading import threading
from typing import List, Union, Generator from typing import List, Union, Generator
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
@ -17,9 +19,9 @@ class MediaServerChain(ChainBase):
媒体服务器处理链 媒体服务器处理链
""" """
def __init__(self): def __init__(self, db: Session = None):
super().__init__() super().__init__(db)
self.mediaserverdb = MediaServerOper() self.mediaserverdb = MediaServerOper(db)
def librarys(self) -> List[schemas.MediaServerLibrary]: def librarys(self) -> List[schemas.MediaServerLibrary]:
""" """

View File

@ -27,12 +27,12 @@ class MessageChain(ChainBase):
# 每页数据量 # 每页数据量
_page_size: int = 8 _page_size: int = 8
def __init__(self): def __init__(self, db: Session = None):
super().__init__() super().__init__(db)
self.downloadchain = DownloadChain() self.downloadchain = DownloadChain(self._db)
self.subscribechain = SubscribeChain() self.subscribechain = SubscribeChain(self._db)
self.searchchain = SearchChain() self.searchchain = SearchChain(self._db)
self.medtachain = MediaChain() self.medtachain = MediaChain(self._db)
self.torrent = TorrentHelper() self.torrent = TorrentHelper()
self.eventmanager = EventManager() self.eventmanager = EventManager()
self.torrenthelper = TorrentHelper() self.torrenthelper = TorrentHelper()

View File

@ -4,6 +4,8 @@ import time
from datetime import datetime from datetime import datetime
from typing import Tuple, Optional from typing import Tuple, Optional
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.download import DownloadChain from app.chain.download import DownloadChain
from app.core.config import settings from app.core.config import settings
@ -25,12 +27,12 @@ class RssChain(ChainBase):
RSS处理链 RSS处理链
""" """
def __init__(self): def __init__(self, db: Session = None):
super().__init__() super().__init__(db)
self.rssoper = RssOper() self.rssoper = RssOper(self._db)
self.sites = SitesHelper() self.sites = SitesHelper()
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper(self._db)
self.downloadchain = DownloadChain() self.downloadchain = DownloadChain(self._db)
self.message = MessageHelper() self.message = MessageHelper()
def add(self, title: str, year: str, def add(self, title: str, year: str,

View File

@ -4,6 +4,8 @@ from datetime import datetime
from typing import Dict from typing import Dict
from typing import List, Optional from typing import List, Optional
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.core.context import Context from app.core.context import Context
from app.core.context import MediaInfo, TorrentInfo from app.core.context import MediaInfo, TorrentInfo
@ -23,11 +25,11 @@ class SearchChain(ChainBase):
站点资源搜索处理链 站点资源搜索处理链
""" """
def __init__(self): def __init__(self, db: Session = None):
super().__init__() super().__init__(db)
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
self.progress = ProgressHelper() self.progress = ProgressHelper()
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper(self._db)
self.torrenthelper = TorrentHelper() self.torrenthelper = TorrentHelper()
def search_by_tmdbid(self, tmdbid: int, mtype: MediaType = None) -> List[Context]: def search_by_tmdbid(self, tmdbid: int, mtype: MediaType = None) -> List[Context]:

View File

@ -1,5 +1,7 @@
from typing import Union, Tuple from typing import Union, Tuple
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.db.models.site import Site from app.db.models.site import Site
@ -20,9 +22,9 @@ class SiteChain(ChainBase):
站点管理处理链 站点管理处理链
""" """
def __init__(self): def __init__(self, db: Session = None):
super().__init__() super().__init__(db)
self.siteoper = SiteOper() self.siteoper = SiteOper(self._db)
self.cookiehelper = CookieHelper() self.cookiehelper = CookieHelper()
self.message = MessageHelper() self.message = MessageHelper()

View File

@ -3,6 +3,8 @@ import re
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Union, Tuple from typing import Dict, List, Optional, Union, Tuple
from requests import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.download import DownloadChain from app.chain.download import DownloadChain
from app.chain.search import SearchChain from app.chain.search import SearchChain
@ -27,14 +29,14 @@ class SubscribeChain(ChainBase):
_cache_file = "__torrents_cache__" _cache_file = "__torrents_cache__"
def __init__(self): def __init__(self, db: Session = None):
super().__init__() super().__init__(db)
self.downloadchain = DownloadChain() self.downloadchain = DownloadChain(self._db)
self.searchchain = SearchChain() self.searchchain = SearchChain(self._db)
self.subscribehelper = SubscribeOper() self.subscribehelper = SubscribeOper(self._db)
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
self.message = MessageHelper() self.message = MessageHelper()
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper(self._db)
def add(self, title: str, year: str, def add(self, title: str, year: str,
mtype: MediaType = None, mtype: MediaType = None,

View File

@ -5,6 +5,8 @@ import threading
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.core.context import MediaInfo from app.core.context import MediaInfo
@ -28,10 +30,10 @@ class TransferChain(ChainBase):
文件转移处理链 文件转移处理链
""" """
def __init__(self): def __init__(self, db: Session = None):
super().__init__() super().__init__(db)
self.downloadhis = DownloadHistoryOper() self.downloadhis = DownloadHistoryOper(self._db)
self.transferhis = TransferHistoryOper() self.transferhis = TransferHistoryOper(self._db)
self.progress = ProgressHelper() self.progress = ProgressHelper()
def process(self, arg_str: str = None, channel: MessageChannel = None, userid: Union[str, int] = None) -> bool: def process(self, arg_str: str = None, channel: MessageChannel = None, userid: Union[str, int] = None) -> bool:

View File

@ -33,8 +33,11 @@ class DbOper:
_db: Session = None _db: Session = None
def __init__(self, _db=SessionLocal()): def __init__(self, db: Session = None):
self._db = _db if db:
self._db = db
else:
self._db = SessionLocal()
def __del__(self): def __del__(self):
if self._db: if self._db:

View File

@ -15,7 +15,7 @@ class DownloadHistoryOper(DbOper):
按路径查询下载记录 按路径查询下载记录
:param path: 数据key :param path: 数据key
""" """
return DownloadHistory.get_by_path(self._db, path) return DownloadHistory.get_by_path(self._db, str(path))
def get_by_hash(self, download_hash: str) -> Any: def get_by_hash(self, download_hash: str) -> Any:
""" """

View File

@ -1,7 +1,9 @@
import json import json
from typing import Optional from typing import Optional
from app.db import DbOper, SessionLocal from sqlalchemy.orm import Session
from app.db import DbOper
from app.db.models.mediaserver import MediaServerItem from app.db.models.mediaserver import MediaServerItem
@ -10,7 +12,7 @@ class MediaServerOper(DbOper):
媒体服务器数据管理 媒体服务器数据管理
""" """
def __init__(self, db=SessionLocal()): def __init__(self, db: Session = None):
super().__init__(db) super().__init__(db)
def add(self, **kwargs) -> bool: def add(self, **kwargs) -> bool:

View File

@ -1,6 +1,8 @@
from typing import List from typing import List
from app.db import DbOper, SessionLocal from sqlalchemy.orm import Session
from app.db import DbOper
from app.db.models.rss import Rss from app.db.models.rss import Rss
@ -9,7 +11,7 @@ class RssOper(DbOper):
RSS订阅数据管理 RSS订阅数据管理
""" """
def __init__(self, db=SessionLocal()): def __init__(self, db: Session = None):
super().__init__(db) super().__init__(db)
def add(self, **kwargs) -> bool: def add(self, **kwargs) -> bool:

View File

@ -1,22 +1,24 @@
import json import json
from typing import Any, Union from typing import Any, Union
from app.db import DbOper, SessionLocal from sqlalchemy.orm import Session
from app.db import DbOper
from app.db.models.systemconfig import SystemConfig from app.db.models.systemconfig import SystemConfig
from app.schemas.types import SystemConfigKey
from app.utils.object import ObjectUtils from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton from app.utils.singleton import Singleton
from app.schemas.types import SystemConfigKey
class SystemConfigOper(DbOper, metaclass=Singleton): class SystemConfigOper(DbOper, metaclass=Singleton):
# 配置对象 # 配置对象
__SYSTEMCONF: dict = {} __SYSTEMCONF: dict = {}
def __init__(self, _db=SessionLocal()): def __init__(self, db: Session = None):
""" """
加载配置到内存 加载配置到内存
""" """
super().__init__(_db) super().__init__(db)
for item in SystemConfig.list(self._db): for item in SystemConfig.list(self._db):
if ObjectUtils.is_obj(item.value): if ObjectUtils.is_obj(item.value):
self.__SYSTEMCONF[item.key] = json.loads(item.value) self.__SYSTEMCONF[item.key] = json.loads(item.value)