fix subscribe

This commit is contained in:
jxxghp 2023-07-04 15:56:13 +08:00
parent 9f18d4a1df
commit d97bc41ca7

View File

@ -1,4 +1,4 @@
from typing import List, Any from typing import List, Any, Optional
from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -16,15 +16,22 @@ from app.schemas.types import MediaType
router = APIRouter() router = APIRouter()
def start_subscribe_chain(title: str, year: str, def start_subscribe_add(title: str, year: str,
mtype: MediaType, tmdbid: int, season: int, username: str): mtype: MediaType, tmdbid: int, season: int, username: str):
""" """
启动订阅链式任务 启动订阅任务
""" """
SubscribeChain().add(title=title, year=year, SubscribeChain().add(title=title, year=year,
mtype=mtype, tmdbid=tmdbid, season=season, username=username) mtype=mtype, tmdbid=tmdbid, season=season, username=username)
def start_subscribe_search(sid: Optional[int], state: Optional[str]):
"""
启动订阅搜索任务
"""
SubscribeChain().search(sid=sid, state=state)
@router.get("/", summary="所有订阅", response_model=List[schemas.Subscribe]) @router.get("/", summary="所有订阅", response_model=List[schemas.Subscribe])
async def read_subscribes( async def read_subscribes(
db: Session = Depends(get_db), db: Session = Depends(get_db),
@ -176,7 +183,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
user_name = req_json.get("request", {}).get("requestedBy_username") user_name = req_json.get("request", {}).get("requestedBy_username")
# 添加订阅 # 添加订阅
if media_type == MediaType.MOVIE: if media_type == MediaType.MOVIE:
background_tasks.add_task(start_subscribe_chain, background_tasks.add_task(start_subscribe_add,
mtype=media_type, mtype=media_type,
tmdbid=tmdbId, tmdbid=tmdbId,
title=subject, title=subject,
@ -190,7 +197,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
seasons = [int(str(sea).strip()) for sea in extra.get("value").split(", ") if str(sea).isdigit()] seasons = [int(str(sea).strip()) for sea in extra.get("value").split(", ") if str(sea).isdigit()]
break break
for season in seasons: for season in seasons:
background_tasks.add_task(start_subscribe_chain, background_tasks.add_task(start_subscribe_add,
mtype=media_type, mtype=media_type,
tmdbid=tmdbId, tmdbid=tmdbId,
title=subject, title=subject,
@ -211,11 +218,24 @@ async def refresh_subscribes(
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/search", summary="搜索订阅", response_model=schemas.Response) @router.get("/search/{subscribe_id}", summary="搜索订阅", response_model=schemas.Response)
async def search_subscribes( async def search_subscribe(
subscribe_id: int,
background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
搜索所有订阅 搜索所有订阅
""" """
SubscribeChain().search(state='R') background_tasks.add_task(start_subscribe_search, sid=subscribe_id, state=None)
return schemas.Response(success=True)
@router.get("/search", summary="搜索所有订阅", response_model=schemas.Response)
async def search_subscribes(
background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
搜索所有订阅
"""
background_tasks.add_task(start_subscribe_search, sid=None, state='R')
return schemas.Response(success=True) return schemas.Response(success=True)