From b58993798d7491824ce7625fce4c61e09f4b9548 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 13 Jul 2023 13:00:44 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=E8=AE=A2=E9=98=85=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E9=80=89=E6=8B=A9=E7=AB=99=E7=82=B9=E8=8C=83=E5=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- alembic/gen.py | 19 +++++++++++++++ alembic/versions/9f4edd55c2d4_1_0_0.py | 32 ++++++++++++++++++++++++++ app/api/endpoints/subscribe.py | 22 ++++++++++++++---- app/chain/search.py | 16 +++++++++---- app/chain/subscribe.py | 13 ++++++++++- app/db/models/subscribe.py | 3 ++- app/schemas/subscribe.py | 4 +++- 7 files changed, 97 insertions(+), 12 deletions(-) create mode 100644 alembic/gen.py create mode 100644 alembic/versions/9f4edd55c2d4_1_0_0.py diff --git a/alembic/gen.py b/alembic/gen.py new file mode 100644 index 00000000..876b8525 --- /dev/null +++ b/alembic/gen.py @@ -0,0 +1,19 @@ +import importlib +from pathlib import Path + +from alembic.config import Config as AlembicConfig +from alembic.command import revision as alembic_revision + +from app.core.config import settings + +# 导入模块,避免建表缺失 +for module in Path(__file__).with_name("models").glob("*.py"): + importlib.import_module(f"app.db.models.{module.stem}") + +db_version = input("请输入版本号:") +db_location = settings.CONFIG_PATH / 'user.db' +script_location = settings.ROOT_PATH / 'alembic' +alembic_cfg = AlembicConfig() +alembic_cfg.set_main_option('script_location', str(script_location)) +alembic_cfg.set_main_option('sqlalchemy.url', f"sqlite:///{db_location}") +alembic_revision(alembic_cfg, db_version, True) diff --git a/alembic/versions/9f4edd55c2d4_1_0_0.py b/alembic/versions/9f4edd55c2d4_1_0_0.py new file mode 100644 index 00000000..6c29aaea --- /dev/null +++ b/alembic/versions/9f4edd55c2d4_1_0_0.py @@ -0,0 +1,32 @@ +"""1.0.0 + +Revision ID: 9f4edd55c2d4 +Revises: +Create Date: 2023-07-13 12:27:26.402317 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '9f4edd55c2d4' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + try: + with op.batch_alter_table("subscribe") as batch_op: + batch_op.add_column(sa.Column('sites', sa.Text, nullable=True)) + except Exception as e: + pass + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + pass + # ### end Alembic commands ### diff --git a/app/api/endpoints/subscribe.py b/app/api/endpoints/subscribe.py index 6e1b3d0a..281fcf6a 100644 --- a/app/api/endpoints/subscribe.py +++ b/app/api/endpoints/subscribe.py @@ -1,3 +1,4 @@ +import json from typing import List, Any, Optional from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header @@ -39,7 +40,11 @@ def read_subscribes( """ 查询所有订阅 """ - return Subscribe.list(db) + subscribes = Subscribe.list(db) + for subscribe in subscribes: + if subscribe.sites: + subscribe.sites = json.loads(subscribe.sites) + return subscribes @router.post("/", summary="新增订阅", response_model=schemas.Response) @@ -87,6 +92,8 @@ def update_subscribe( subscribe = Subscribe.get(db, subscribe_in.id) if not subscribe: return schemas.Response(success=False, message="订阅不存在") + if subscribe_in.sites: + subscribe_in.sites = json.dumps(subscribe_in.sites) subscribe.update(db, subscribe_in.dict()) return schemas.Response(success=True) @@ -106,6 +113,8 @@ def subscribe_mediaid( result = Subscribe.get_by_doubanid(db, mediaid[7:]) else: result = None + if result: + result.sites = json.loads(result.sites) return result if result else Subscribe() @@ -118,7 +127,10 @@ def read_subscribe( """ 根据订阅编号查询订阅信息 """ - return Subscribe.get(db, subscribe_id) + subscribe = Subscribe.get(db, subscribe_id) + if subscribe.sites: + subscribe.sites = json.loads(subscribe.sites) + return subscribe @router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response) @@ -153,8 +165,8 @@ def delete_subscribe( @router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response) -def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, - authorization: str = Header(None)) -> Any: +async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, + authorization: str = Header(None)) -> Any: """ Jellyseerr/Overseerr订阅 """ @@ -163,7 +175,7 @@ def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, status_code=400, detail="授权失败", ) - req_json = request.json() + req_json = await request.json() if not req_json: raise HTTPException( status_code=500, diff --git a/app/chain/search.py b/app/chain/search.py index 45bd953e..ee83efd9 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -84,12 +84,14 @@ class SearchChain(ChainBase): def process(self, mediainfo: MediaInfo, keyword: str = None, - no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None) -> List[Context]: + no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None, + sites: List[int] = None) -> List[Context]: """ 根据媒体信息搜索种子资源,精确匹配,应用过滤规则,同时根据no_exists过滤本地已存在的资源 :param mediainfo: 媒体信息 :param keyword: 搜索关键词 :param no_exists: 缺失的媒体信息 + :param sites: 站点ID列表,为空时搜索所有站点 """ logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...') # 补充媒体信息 @@ -109,7 +111,8 @@ class SearchChain(ChainBase): # 执行搜索 torrents: List[TorrentInfo] = self.__search_all_sites( mediainfo=mediainfo, - keyword=keyword + keyword=keyword, + sites=sites ) if not torrents: logger.warn(f'{keyword or mediainfo.title} 未搜索到资源') @@ -199,17 +202,22 @@ class SearchChain(ChainBase): return contexts def __search_all_sites(self, mediainfo: Optional[MediaInfo] = None, - keyword: str = None) -> Optional[List[TorrentInfo]]: + keyword: str = None, + sites: List[int] = None) -> Optional[List[TorrentInfo]]: """ 多线程搜索多个站点 :param mediainfo: 识别的媒体信息 :param keyword: 搜索关键词,如有按关键词搜索,否则按媒体信息名称搜索 + :param sites: 指定站点ID列表,如有则只搜索指定站点,否则搜索所有站点 :reutrn: 资源列表 """ # 未开启的站点不搜索 indexer_sites = [] # 配置的索引站点 - config_indexers = [str(sid) for sid in self.systemconfig.get(SystemConfigKey.IndexerSites) or []] + if sites: + config_indexers = [str(sid) for sid in sites] + else: + config_indexers = [str(sid) for sid in self.systemconfig.get(SystemConfigKey.IndexerSites) or []] for indexer in self.siteshelper.get_indexers(): # 检查站点索引开关 if not config_indexers or str(indexer.get("id")) in config_indexers: diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index f2d82f0a..a63d7ad9 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -201,10 +201,16 @@ class SubscribeChain(ChainBase): start_episode=subscribe.start_episode, ) + # 站点范围 + if subscribe.sites: + sites = json.loads(subscribe.sites) + else: + sites = None # 搜索 contexts = self.searchchain.process(mediainfo=mediainfo, keyword=subscribe.keyword, - no_exists=no_exists) + no_exists=no_exists, + sites=sites) if not contexts: logger.warn(f'订阅 {subscribe.keyword or subscribe.name} 未搜索到资源') # 未搜索到资源,但本地缺失可能有变化,更新订阅剩余集数 @@ -357,6 +363,11 @@ class SubscribeChain(ChainBase): torrent_meta = context.meta_info torrent_mediainfo = context.media_info torrent_info = context.torrent_info + # 不在订阅站点范围的不处理 + if subscribe.sites: + sub_sites = json.loads(subscribe.sites) + if sub_sites and torrent_info.site not in sub_sites: + continue # 如果是电视剧过滤掉已经下载的集数 if torrent_mediainfo.type == MediaType.TV: if self.__check_subscribe_note(subscribe, torrent_meta.episode_list): diff --git a/app/db/models/subscribe.py b/app/db/models/subscribe.py index 0eaa3024..385425e6 100644 --- a/app/db/models/subscribe.py +++ b/app/db/models/subscribe.py @@ -2,7 +2,6 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session from app.db.models import Base -from app.schemas import MediaType class Subscribe(Base): @@ -52,6 +51,8 @@ class Subscribe(Base): last_update = Column(String) # 订阅用户 username = Column(String) + # 订阅站点 + sites = Column(String) @staticmethod def exists(db: Session, tmdbid: int, season: int = None): diff --git a/app/schemas/subscribe.py b/app/schemas/subscribe.py index 75e8252c..4dd42a40 100644 --- a/app/schemas/subscribe.py +++ b/app/schemas/subscribe.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List from pydantic import BaseModel @@ -45,6 +45,8 @@ class Subscribe(BaseModel): last_update: Optional[str] = None # 订阅用户 username: Optional[str] = None + # 订阅站点 + sites: Optional[List[int]] = None class Config: orm_mode = True