From e10776cf1d88a8db403732cdfa167b2f4595614b Mon Sep 17 00:00:00 2001 From: jxxghp Date: Mon, 12 Jun 2023 09:27:39 +0800 Subject: [PATCH] =?UTF-8?q?add=20=E7=AB=99=E7=82=B9=E3=80=81=E8=AE=A2?= =?UTF-8?q?=E9=98=85API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/endpoints/douban.py | 9 ++---- app/api/endpoints/media.py | 9 ++---- app/api/endpoints/sites.py | 38 +++++++++++++--------- app/api/endpoints/subscribes.py | 56 ++++++++++++++++++++++----------- app/db/models/subscribe.py | 4 +++ 5 files changed, 69 insertions(+), 47 deletions(-) diff --git a/app/api/endpoints/douban.py b/app/api/endpoints/douban.py index dfbea1e7..46a3c2ca 100644 --- a/app/api/endpoints/douban.py +++ b/app/api/endpoints/douban.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks +from fastapi import APIRouter, Depends, BackgroundTasks from app import schemas from app.chain.douban_sync import DoubanSyncChain @@ -18,14 +18,9 @@ def start_douban_chain(): @router.get("/sync", response_model=schemas.Response) async def sync_douban( background_tasks: BackgroundTasks, - current_user: User = Depends(get_current_active_superuser)): + _: User = Depends(get_current_active_superuser)): """ 查询所有订阅 """ - if not current_user: - raise HTTPException( - status_code=400, - detail="需要授权", - ) background_tasks.add_task(start_douban_chain) return {"success": True} diff --git a/app/api/endpoints/media.py b/app/api/endpoints/media.py index 05c08abc..3ec888a7 100644 --- a/app/api/endpoints/media.py +++ b/app/api/endpoints/media.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException, Depends +from fastapi import APIRouter, Depends from app import schemas from app.chain.identify import IdentifyChain @@ -11,15 +11,10 @@ router = APIRouter() @router.post("/recognize", response_model=schemas.Context) async def recognize(title: str, subtitle: str = None, - current_user: User = Depends(get_current_active_user)): + _: User = Depends(get_current_active_user)): """ 识别媒体信息 """ - if not current_user: - raise HTTPException( - status_code=400, - detail="需要授权", - ) # 识别媒体信息 context = IdentifyChain().process(title=title, subtitle=subtitle) return context.to_dict() diff --git a/app/api/endpoints/sites.py b/app/api/endpoints/sites.py index f04f3b4e..08f86269 100644 --- a/app/api/endpoints/sites.py +++ b/app/api/endpoints/sites.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Any from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session @@ -8,35 +8,45 @@ from app.chain.cookiecloud import CookieCloudChain from app.db import get_db from app.db.models.site import Site from app.db.models.user import User -from app.db.userauth import get_current_active_user +from app.db.userauth import get_current_active_user, get_current_active_superuser router = APIRouter() @router.get("/", response_model=List[schemas.Site]) async def read_sites(db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user)) -> List[dict]: + _: User = Depends(get_current_active_user)) -> List[dict]: """ 获取站点列表 """ - if not current_user: - raise HTTPException( - status_code=400, - detail="需要授权", - ) return Site.list(db) +@router.post("/update", response_model=schemas.Site) +async def update_site( + *, + db: Session = Depends(get_db), + site_in: schemas.Site, + _: User = Depends(get_current_active_superuser), +) -> Any: + """ + 更新站点信息 + """ + site = Site.get(db, site_in.id) + if not site: + raise HTTPException( + status_code=404, + detail=f"站点 {site_in.id} 不存在", + ) + site.update(db, **site_in.dict()) + return site + + @router.get("/cookiecloud", response_model=schemas.Response) -async def cookie_cloud_sync(current_user: User = Depends(get_current_active_user)) -> dict: +async def cookie_cloud_sync(_: User = Depends(get_current_active_user)) -> dict: """ 运行CookieCloud同步站点信息 """ - if not current_user: - raise HTTPException( - status_code=400, - detail="需要授权", - ) status, error_msg = CookieCloudChain().process() if not status: return {"success": False, "message": error_msg} diff --git a/app/api/endpoints/subscribes.py b/app/api/endpoints/subscribes.py index e41d85b2..63f7c0e8 100644 --- a/app/api/endpoints/subscribes.py +++ b/app/api/endpoints/subscribes.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Any from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header from sqlalchemy.orm import Session @@ -27,18 +27,46 @@ def start_subscribe_chain(title: str, @router.get("/", response_model=List[schemas.Subscribe]) async def read_subscribes( db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser)): + _: User = Depends(get_current_active_superuser)): """ 查询所有订阅 """ - if not current_user: - raise HTTPException( - status_code=400, - detail="需要授权", - ) return Subscribe.list(db) +@router.post("/", response_model=schemas.Response) +async def create_subscribe( + *, + subscribe_in: schemas.Subscribe, + _: User = Depends(get_current_active_superuser), +) -> Any: + """ + 新增订阅 + """ + result = SubscribeChain().process(**subscribe_in.dict()) + return {"success": result} + + +@router.post("/update", response_model=schemas.Subscribe) +async def update_subscribe( + *, + db: Session = Depends(get_db), + subscribe_in: schemas.Subscribe, + _: User = Depends(get_current_active_superuser), +) -> Any: + """ + 更新订阅信息 + """ + subscribe = Subscribe.get(db, subscribe_in.id) + if not subscribe: + raise HTTPException( + status_code=404, + detail=f"订阅 {subscribe_in.id} 不存在", + ) + subscribe.update(db, **subscribe_in.dict()) + return subscribe + + @router.post("/seerr", response_model=schemas.Response) async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, authorization: str = Header(None)): @@ -92,29 +120,19 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, @router.get("/refresh", response_model=schemas.Response) async def refresh_subscribes( - current_user: User = Depends(get_current_active_superuser)): + _: User = Depends(get_current_active_superuser)): """ 刷新所有订阅 """ - if not current_user: - raise HTTPException( - status_code=400, - detail="需要授权", - ) SubscribeChain().refresh() return {"success": True} @router.get("/search", response_model=schemas.Response) async def search_subscribes( - current_user: User = Depends(get_current_active_superuser)): + _: User = Depends(get_current_active_superuser)): """ 搜索所有订阅 """ - if not current_user: - raise HTTPException( - status_code=400, - detail="需要授权", - ) SubscribeChain().search(state='R') return {"success": True} diff --git a/app/db/models/subscribe.py b/app/db/models/subscribe.py index 64beb98c..40bcf008 100644 --- a/app/db/models/subscribe.py +++ b/app/db/models/subscribe.py @@ -37,3 +37,7 @@ class Subscribe(Base): @staticmethod def get_by_state(db: Session, state: str): return db.query(Subscribe).filter(Subscribe.state == state).all() + + @staticmethod + def get_by_tmdbid(db: Session, tmdbid: str): + return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first()