add 站点、订阅API

This commit is contained in:
jxxghp 2023-06-12 09:27:39 +08:00
parent 6e8b687545
commit e10776cf1d
5 changed files with 69 additions and 47 deletions

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks from fastapi import APIRouter, Depends, BackgroundTasks
from app import schemas from app import schemas
from app.chain.douban_sync import DoubanSyncChain from app.chain.douban_sync import DoubanSyncChain
@ -18,14 +18,9 @@ def start_douban_chain():
@router.get("/sync", response_model=schemas.Response) @router.get("/sync", response_model=schemas.Response)
async def sync_douban( async def sync_douban(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
current_user: User = Depends(get_current_active_superuser)): _: User = Depends(get_current_active_superuser)):
""" """
查询所有订阅 查询所有订阅
""" """
if not current_user:
raise HTTPException(
status_code=400,
detail="需要授权",
)
background_tasks.add_task(start_douban_chain) background_tasks.add_task(start_douban_chain)
return {"success": True} return {"success": True}

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, HTTPException, Depends from fastapi import APIRouter, Depends
from app import schemas from app import schemas
from app.chain.identify import IdentifyChain from app.chain.identify import IdentifyChain
@ -11,15 +11,10 @@ router = APIRouter()
@router.post("/recognize", response_model=schemas.Context) @router.post("/recognize", response_model=schemas.Context)
async def recognize(title: str, async def recognize(title: str,
subtitle: str = None, subtitle: str = None,
current_user: User = Depends(get_current_active_user)): _: User = Depends(get_current_active_user)):
""" """
识别媒体信息 识别媒体信息
""" """
if not current_user:
raise HTTPException(
status_code=400,
detail="需要授权",
)
# 识别媒体信息 # 识别媒体信息
context = IdentifyChain().process(title=title, subtitle=subtitle) context = IdentifyChain().process(title=title, subtitle=subtitle)
return context.to_dict() return context.to_dict()

View File

@ -1,4 +1,4 @@
from typing import List from typing import List, Any
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -8,35 +8,45 @@ from app.chain.cookiecloud import CookieCloudChain
from app.db import get_db from app.db import get_db
from app.db.models.site import Site from app.db.models.site import Site
from app.db.models.user import User from app.db.models.user import User
from app.db.userauth import get_current_active_user from app.db.userauth import get_current_active_user, get_current_active_superuser
router = APIRouter() router = APIRouter()
@router.get("/", response_model=List[schemas.Site]) @router.get("/", response_model=List[schemas.Site])
async def read_sites(db: Session = Depends(get_db), async def read_sites(db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)) -> List[dict]: _: User = Depends(get_current_active_user)) -> List[dict]:
""" """
获取站点列表 获取站点列表
""" """
if not current_user:
raise HTTPException(
status_code=400,
detail="需要授权",
)
return Site.list(db) return Site.list(db)
@router.post("/update", response_model=schemas.Site)
async def update_site(
*,
db: Session = Depends(get_db),
site_in: schemas.Site,
_: User = Depends(get_current_active_superuser),
) -> Any:
"""
更新站点信息
"""
site = Site.get(db, site_in.id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_in.id} 不存在",
)
site.update(db, **site_in.dict())
return site
@router.get("/cookiecloud", response_model=schemas.Response) @router.get("/cookiecloud", response_model=schemas.Response)
async def cookie_cloud_sync(current_user: User = Depends(get_current_active_user)) -> dict: async def cookie_cloud_sync(_: User = Depends(get_current_active_user)) -> dict:
""" """
运行CookieCloud同步站点信息 运行CookieCloud同步站点信息
""" """
if not current_user:
raise HTTPException(
status_code=400,
detail="需要授权",
)
status, error_msg = CookieCloudChain().process() status, error_msg = CookieCloudChain().process()
if not status: if not status:
return {"success": False, "message": error_msg} return {"success": False, "message": error_msg}

View File

@ -1,4 +1,4 @@
from typing import List from typing import List, Any
from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -27,18 +27,46 @@ def start_subscribe_chain(title: str,
@router.get("/", response_model=List[schemas.Subscribe]) @router.get("/", response_model=List[schemas.Subscribe])
async def read_subscribes( async def read_subscribes(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_superuser)): _: User = Depends(get_current_active_superuser)):
""" """
查询所有订阅 查询所有订阅
""" """
if not current_user:
raise HTTPException(
status_code=400,
detail="需要授权",
)
return Subscribe.list(db) return Subscribe.list(db)
@router.post("/", response_model=schemas.Response)
async def create_subscribe(
*,
subscribe_in: schemas.Subscribe,
_: User = Depends(get_current_active_superuser),
) -> Any:
"""
新增订阅
"""
result = SubscribeChain().process(**subscribe_in.dict())
return {"success": result}
@router.post("/update", response_model=schemas.Subscribe)
async def update_subscribe(
*,
db: Session = Depends(get_db),
subscribe_in: schemas.Subscribe,
_: User = Depends(get_current_active_superuser),
) -> Any:
"""
更新订阅信息
"""
subscribe = Subscribe.get(db, subscribe_in.id)
if not subscribe:
raise HTTPException(
status_code=404,
detail=f"订阅 {subscribe_in.id} 不存在",
)
subscribe.update(db, **subscribe_in.dict())
return subscribe
@router.post("/seerr", response_model=schemas.Response) @router.post("/seerr", response_model=schemas.Response)
async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
authorization: str = Header(None)): authorization: str = Header(None)):
@ -92,29 +120,19 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
@router.get("/refresh", response_model=schemas.Response) @router.get("/refresh", response_model=schemas.Response)
async def refresh_subscribes( async def refresh_subscribes(
current_user: User = Depends(get_current_active_superuser)): _: User = Depends(get_current_active_superuser)):
""" """
刷新所有订阅 刷新所有订阅
""" """
if not current_user:
raise HTTPException(
status_code=400,
detail="需要授权",
)
SubscribeChain().refresh() SubscribeChain().refresh()
return {"success": True} return {"success": True}
@router.get("/search", response_model=schemas.Response) @router.get("/search", response_model=schemas.Response)
async def search_subscribes( async def search_subscribes(
current_user: User = Depends(get_current_active_superuser)): _: User = Depends(get_current_active_superuser)):
""" """
搜索所有订阅 搜索所有订阅
""" """
if not current_user:
raise HTTPException(
status_code=400,
detail="需要授权",
)
SubscribeChain().search(state='R') SubscribeChain().search(state='R')
return {"success": True} return {"success": True}

View File

@ -37,3 +37,7 @@ class Subscribe(Base):
@staticmethod @staticmethod
def get_by_state(db: Session, state: str): def get_by_state(db: Session, state: str):
return db.query(Subscribe).filter(Subscribe.state == state).all() return db.query(Subscribe).filter(Subscribe.state == state).all()
@staticmethod
def get_by_tmdbid(db: Session, tmdbid: str):
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first()