From 5deec011f8e98c940ea618ff3ab899ca66122e2c Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 10 Aug 2023 16:50:02 +0800 Subject: [PATCH] fix rss api --- app/api/endpoints/rss.py | 44 ++++++++++++++++++++++++++++++++++++++++ app/chain/rss.py | 13 ++++++++++-- app/db/rss_oper.py | 4 +++- app/scheduler.py | 5 +++++ 4 files changed, 63 insertions(+), 3 deletions(-) diff --git a/app/api/endpoints/rss.py b/app/api/endpoints/rss.py index 5d0b9e87..b7a7191c 100644 --- a/app/api/endpoints/rss.py +++ b/app/api/endpoints/rss.py @@ -2,17 +2,26 @@ from typing import List, Any from fastapi import APIRouter, Depends from sqlalchemy.orm import Session +from starlette.background import BackgroundTasks from app import schemas from app.chain.rss import RssChain from app.core.security import verify_token from app.db import get_db from app.db.models.rss import Rss +from app.helper.rss import RssHelper from app.schemas import MediaType router = APIRouter() +def start_rss_refresh(rssid: int = None): + """ + 启动自定义订阅刷新 + """ + RssChain().refresh(rssid=rssid, manual=True) + + @router.get("/", summary="所有自定义订阅", response_model=List[schemas.Rss]) def read_rsses( db: Session = Depends(get_db), @@ -65,6 +74,41 @@ def update_rss( return schemas.Response(success=True) +@router.get("/preview/{rssid}", summary="预览自定义订阅", response_model=List[schemas.TorrentInfo]) +def preview_rss( + rssid: int, + db: Session = Depends(get_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: + """ + 根据ID查询自定义订阅RSS报文 + """ + rssinfo: Rss = Rss.get(db, rssid) + if not rssinfo: + return [] + torrents = RssHelper.parse(rssinfo.url, proxy=True if rssinfo.proxy else False) or [] + return [schemas.TorrentInfo( + title=t.get("title"), + description=t.get("description"), + enclosure=t.get("enclosure"), + size=t.get("size"), + page_url=t.get("link"), + pubdate=t["pubdate"].strftime("%Y-%m-%d %H:%M:%S") if t.get("pubdate") else None, + ) for t in torrents] + + +@router.get("/refresh/{rssid}", summary="刷新自定义订阅", response_model=schemas.Response) +def refresh_rss( + rssid: int, + background_tasks: BackgroundTasks, + _: schemas.TokenPayload = Depends(verify_token)) -> Any: + """ + 根据ID刷新自定义订阅 + """ + background_tasks.add_task(start_rss_refresh, + rssid=rssid) + return schemas.Response(success=True) + + @router.get("/{rssid}", summary="查询自定义订阅详情", response_model=schemas.Rss) def read_rss( rssid: int, diff --git a/app/chain/rss.py b/app/chain/rss.py index 3f256d0f..67f7c4f7 100644 --- a/app/chain/rss.py +++ b/app/chain/rss.py @@ -11,6 +11,7 @@ from app.core.context import Context, TorrentInfo, MediaInfo from app.core.metainfo import MetaInfo from app.db.rss_oper import RssOper from app.db.systemconfig_oper import SystemConfigOper +from app.helper.message import MessageHelper from app.helper.rss import RssHelper from app.helper.sites import SitesHelper from app.log import logger @@ -30,6 +31,7 @@ class RssChain(ChainBase): self.sites = SitesHelper() self.systemconfig = SystemConfigOper() self.downloadchain = DownloadChain() + self.message = MessageHelper() def add(self, title: str, year: str, mtype: MediaType = None, @@ -104,14 +106,16 @@ class RssChain(ChainBase): # 返回结果 return sid, "" - def refresh(self): + def refresh(self, rssid: int = None, manual: bool = False): """ 刷新RSS订阅数据 """ # 所有RSS订阅 logger.info("开始刷新RSS订阅数据 ...") - rss_tasks = self.rssoper.list() or [] + rss_tasks = self.rssoper.list(rssid) or [] for rss_task in rss_tasks: + if not rss_task: + continue if not rss_task.url: continue # 下载Rss报文 @@ -266,3 +270,8 @@ class RssChain(ChainBase): processed=(rss_task.processed or 0) + len(downloads), last_update=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) logger.info("刷新RSS订阅数据完成") + if manual: + if len(rss_tasks) == 1: + self.message.put(f"{rss_tasks[0].name} 自定义订阅刷新完成") + else: + self.message.put(f"自定义订阅刷新完成") diff --git a/app/db/rss_oper.py b/app/db/rss_oper.py index 98e16e8e..88abe070 100644 --- a/app/db/rss_oper.py +++ b/app/db/rss_oper.py @@ -26,10 +26,12 @@ class RssOper(DbOper): """ return Rss.get_by_tmdbid(self._db, tmdbid, season) - def list(self) -> List[Rss]: + def list(self, rssid: int = None) -> List[Rss]: """ 查询所有RSS订阅 """ + if rssid: + return [Rss.get(self._db, rssid)] return Rss.list(self._db) def delete(self, rssid: int) -> bool: diff --git a/app/scheduler.py b/app/scheduler.py index baab3fba..66a08d5c 100644 --- a/app/scheduler.py +++ b/app/scheduler.py @@ -8,6 +8,7 @@ from apscheduler.schedulers.background import BackgroundScheduler from app.chain import ChainBase from app.chain.cookiecloud import CookieCloudChain from app.chain.mediaserver import MediaServerChain +from app.chain.rss import RssChain from app.chain.subscribe import SubscribeChain from app.chain.transfer import TransferChain from app.core.config import settings @@ -69,6 +70,10 @@ class Scheduler(metaclass=Singleton): self._scheduler.add_job(SubscribeChain().refresh, "cron", hour=trigger.hour, minute=trigger.minute, name="订阅刷新") + # 自定义订阅 + self._scheduler.add_job(RssChain().refresh, "interval", + minutes=30, name="自定义订阅刷新") + # 下载器文件转移(每5分钟) self._scheduler.add_job(TransferChain().process, "interval", minutes=5, name="下载文件整理")