feat 订阅支持选择站点范围

This commit is contained in:
jxxghp 2023-07-13 13:00:44 +08:00
parent 670ddaf0a7
commit b58993798d
7 changed files with 97 additions and 12 deletions

19
alembic/gen.py Normal file
View File

@ -0,0 +1,19 @@
import importlib
from pathlib import Path
from alembic.config import Config as AlembicConfig
from alembic.command import revision as alembic_revision
from app.core.config import settings
# 导入模块,避免建表缺失
for module in Path(__file__).with_name("models").glob("*.py"):
importlib.import_module(f"app.db.models.{module.stem}")
db_version = input("请输入版本号:")
db_location = settings.CONFIG_PATH / 'user.db'
script_location = settings.ROOT_PATH / 'alembic'
alembic_cfg = AlembicConfig()
alembic_cfg.set_main_option('script_location', str(script_location))
alembic_cfg.set_main_option('sqlalchemy.url', f"sqlite:///{db_location}")
alembic_revision(alembic_cfg, db_version, True)

View File

@ -0,0 +1,32 @@
"""1.0.0
Revision ID: 9f4edd55c2d4
Revises:
Create Date: 2023-07-13 12:27:26.402317
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '9f4edd55c2d4'
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
try:
with op.batch_alter_table("subscribe") as batch_op:
batch_op.add_column(sa.Column('sites', sa.Text, nullable=True))
except Exception as e:
pass
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@ -1,3 +1,4 @@
import json
from typing import List, Any, Optional from typing import List, Any, Optional
from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header
@ -39,7 +40,11 @@ def read_subscribes(
""" """
查询所有订阅 查询所有订阅
""" """
return Subscribe.list(db) subscribes = Subscribe.list(db)
for subscribe in subscribes:
if subscribe.sites:
subscribe.sites = json.loads(subscribe.sites)
return subscribes
@router.post("/", summary="新增订阅", response_model=schemas.Response) @router.post("/", summary="新增订阅", response_model=schemas.Response)
@ -87,6 +92,8 @@ def update_subscribe(
subscribe = Subscribe.get(db, subscribe_in.id) subscribe = Subscribe.get(db, subscribe_in.id)
if not subscribe: if not subscribe:
return schemas.Response(success=False, message="订阅不存在") return schemas.Response(success=False, message="订阅不存在")
if subscribe_in.sites:
subscribe_in.sites = json.dumps(subscribe_in.sites)
subscribe.update(db, subscribe_in.dict()) subscribe.update(db, subscribe_in.dict())
return schemas.Response(success=True) return schemas.Response(success=True)
@ -106,6 +113,8 @@ def subscribe_mediaid(
result = Subscribe.get_by_doubanid(db, mediaid[7:]) result = Subscribe.get_by_doubanid(db, mediaid[7:])
else: else:
result = None result = None
if result:
result.sites = json.loads(result.sites)
return result if result else Subscribe() return result if result else Subscribe()
@ -118,7 +127,10 @@ def read_subscribe(
""" """
根据订阅编号查询订阅信息 根据订阅编号查询订阅信息
""" """
return Subscribe.get(db, subscribe_id) subscribe = Subscribe.get(db, subscribe_id)
if subscribe.sites:
subscribe.sites = json.loads(subscribe.sites)
return subscribe
@router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response) @router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response)
@ -153,8 +165,8 @@ def delete_subscribe(
@router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response) @router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response)
def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
authorization: str = Header(None)) -> Any: authorization: str = Header(None)) -> Any:
""" """
Jellyseerr/Overseerr订阅 Jellyseerr/Overseerr订阅
""" """
@ -163,7 +175,7 @@ def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
status_code=400, status_code=400,
detail="授权失败", detail="授权失败",
) )
req_json = request.json() req_json = await request.json()
if not req_json: if not req_json:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,

View File

@ -84,12 +84,14 @@ class SearchChain(ChainBase):
def process(self, mediainfo: MediaInfo, def process(self, mediainfo: MediaInfo,
keyword: str = None, keyword: str = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None) -> List[Context]: no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None) -> List[Context]:
""" """
根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源 根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息 :param mediainfo: 媒体信息
:param keyword: 搜索关键词 :param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息 :param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
""" """
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...') logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息 # 补充媒体信息
@ -109,7 +111,8 @@ class SearchChain(ChainBase):
# 执行搜索 # 执行搜索
torrents: List[TorrentInfo] = self.__search_all_sites( torrents: List[TorrentInfo] = self.__search_all_sites(
mediainfo=mediainfo, mediainfo=mediainfo,
keyword=keyword keyword=keyword,
sites=sites
) )
if not torrents: if not torrents:
logger.warn(f'{keyword or mediainfo.title} 未搜索到资源') logger.warn(f'{keyword or mediainfo.title} 未搜索到资源')
@ -199,17 +202,22 @@ class SearchChain(ChainBase):
return contexts return contexts
def __search_all_sites(self, mediainfo: Optional[MediaInfo] = None, def __search_all_sites(self, mediainfo: Optional[MediaInfo] = None,
keyword: str = None) -> Optional[List[TorrentInfo]]: keyword: str = None,
sites: List[int] = None) -> Optional[List[TorrentInfo]]:
""" """
多线程搜索多个站点 多线程搜索多个站点
:param mediainfo: 识别的媒体信息 :param mediainfo: 识别的媒体信息
:param keyword: 搜索关键词如有按关键词搜索否则按媒体信息名称搜索 :param keyword: 搜索关键词如有按关键词搜索否则按媒体信息名称搜索
:param sites: 指定站点ID列表如有则只搜索指定站点否则搜索所有站点
:reutrn: 资源列表 :reutrn: 资源列表
""" """
# 未开启的站点不搜索 # 未开启的站点不搜索
indexer_sites = [] indexer_sites = []
# 配置的索引站点 # 配置的索引站点
config_indexers = [str(sid) for sid in self.systemconfig.get(SystemConfigKey.IndexerSites) or []] if sites:
config_indexers = [str(sid) for sid in sites]
else:
config_indexers = [str(sid) for sid in self.systemconfig.get(SystemConfigKey.IndexerSites) or []]
for indexer in self.siteshelper.get_indexers(): for indexer in self.siteshelper.get_indexers():
# 检查站点索引开关 # 检查站点索引开关
if not config_indexers or str(indexer.get("id")) in config_indexers: if not config_indexers or str(indexer.get("id")) in config_indexers:

View File

@ -201,10 +201,16 @@ class SubscribeChain(ChainBase):
start_episode=subscribe.start_episode, start_episode=subscribe.start_episode,
) )
# 站点范围
if subscribe.sites:
sites = json.loads(subscribe.sites)
else:
sites = None
# 搜索 # 搜索
contexts = self.searchchain.process(mediainfo=mediainfo, contexts = self.searchchain.process(mediainfo=mediainfo,
keyword=subscribe.keyword, keyword=subscribe.keyword,
no_exists=no_exists) no_exists=no_exists,
sites=sites)
if not contexts: if not contexts:
logger.warn(f'订阅 {subscribe.keyword or subscribe.name} 未搜索到资源') logger.warn(f'订阅 {subscribe.keyword or subscribe.name} 未搜索到资源')
# 未搜索到资源,但本地缺失可能有变化,更新订阅剩余集数 # 未搜索到资源,但本地缺失可能有变化,更新订阅剩余集数
@ -357,6 +363,11 @@ class SubscribeChain(ChainBase):
torrent_meta = context.meta_info torrent_meta = context.meta_info
torrent_mediainfo = context.media_info torrent_mediainfo = context.media_info
torrent_info = context.torrent_info torrent_info = context.torrent_info
# 不在订阅站点范围的不处理
if subscribe.sites:
sub_sites = json.loads(subscribe.sites)
if sub_sites and torrent_info.site not in sub_sites:
continue
# 如果是电视剧过滤掉已经下载的集数 # 如果是电视剧过滤掉已经下载的集数
if torrent_mediainfo.type == MediaType.TV: if torrent_mediainfo.type == MediaType.TV:
if self.__check_subscribe_note(subscribe, torrent_meta.episode_list): if self.__check_subscribe_note(subscribe, torrent_meta.episode_list):

View File

@ -2,7 +2,6 @@ from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db.models import Base from app.db.models import Base
from app.schemas import MediaType
class Subscribe(Base): class Subscribe(Base):
@ -52,6 +51,8 @@ class Subscribe(Base):
last_update = Column(String) last_update = Column(String)
# 订阅用户 # 订阅用户
username = Column(String) username = Column(String)
# 订阅站点
sites = Column(String)
@staticmethod @staticmethod
def exists(db: Session, tmdbid: int, season: int = None): def exists(db: Session, tmdbid: int, season: int = None):

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, List
from pydantic import BaseModel from pydantic import BaseModel
@ -45,6 +45,8 @@ class Subscribe(BaseModel):
last_update: Optional[str] = None last_update: Optional[str] = None
# 订阅用户 # 订阅用户
username: Optional[str] = None username: Optional[str] = None
# 订阅站点
sites: Optional[List[int]] = None
class Config: class Config:
orm_mode = True orm_mode = True