feat 订阅支持选择站点范围

This commit is contained in:
jxxghp 2023-07-13 13:00:44 +08:00
parent 670ddaf0a7
commit b58993798d
7 changed files with 97 additions and 12 deletions

19
alembic/gen.py Normal file
View File

@ -0,0 +1,19 @@
import importlib
from pathlib import Path
from alembic.config import Config as AlembicConfig
from alembic.command import revision as alembic_revision
from app.core.config import settings
# 导入模块,避免建表缺失
for module in Path(__file__).with_name("models").glob("*.py"):
importlib.import_module(f"app.db.models.{module.stem}")
db_version = input("请输入版本号:")
db_location = settings.CONFIG_PATH / 'user.db'
script_location = settings.ROOT_PATH / 'alembic'
alembic_cfg = AlembicConfig()
alembic_cfg.set_main_option('script_location', str(script_location))
alembic_cfg.set_main_option('sqlalchemy.url', f"sqlite:///{db_location}")
alembic_revision(alembic_cfg, db_version, True)

View File

@ -0,0 +1,32 @@
"""1.0.0
Revision ID: 9f4edd55c2d4
Revises:
Create Date: 2023-07-13 12:27:26.402317
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '9f4edd55c2d4'
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
try:
with op.batch_alter_table("subscribe") as batch_op:
batch_op.add_column(sa.Column('sites', sa.Text, nullable=True))
except Exception as e:
pass
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@ -1,3 +1,4 @@
import json
from typing import List, Any, Optional
from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header
@ -39,7 +40,11 @@ def read_subscribes(
"""
查询所有订阅
"""
return Subscribe.list(db)
subscribes = Subscribe.list(db)
for subscribe in subscribes:
if subscribe.sites:
subscribe.sites = json.loads(subscribe.sites)
return subscribes
@router.post("/", summary="新增订阅", response_model=schemas.Response)
@ -87,6 +92,8 @@ def update_subscribe(
subscribe = Subscribe.get(db, subscribe_in.id)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
if subscribe_in.sites:
subscribe_in.sites = json.dumps(subscribe_in.sites)
subscribe.update(db, subscribe_in.dict())
return schemas.Response(success=True)
@ -106,6 +113,8 @@ def subscribe_mediaid(
result = Subscribe.get_by_doubanid(db, mediaid[7:])
else:
result = None
if result:
result.sites = json.loads(result.sites)
return result if result else Subscribe()
@ -118,7 +127,10 @@ def read_subscribe(
"""
根据订阅编号查询订阅信息
"""
return Subscribe.get(db, subscribe_id)
subscribe = Subscribe.get(db, subscribe_id)
if subscribe.sites:
subscribe.sites = json.loads(subscribe.sites)
return subscribe
@router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response)
@ -153,8 +165,8 @@ def delete_subscribe(
@router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response)
def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
authorization: str = Header(None)) -> Any:
async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
authorization: str = Header(None)) -> Any:
"""
Jellyseerr/Overseerr订阅
"""
@ -163,7 +175,7 @@ def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
status_code=400,
detail="授权失败",
)
req_json = request.json()
req_json = await request.json()
if not req_json:
raise HTTPException(
status_code=500,

View File

@ -84,12 +84,14 @@ class SearchChain(ChainBase):
def process(self, mediainfo: MediaInfo,
keyword: str = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None) -> List[Context]:
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None) -> List[Context]:
"""
根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
"""
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息
@ -109,7 +111,8 @@ class SearchChain(ChainBase):
# 执行搜索
torrents: List[TorrentInfo] = self.__search_all_sites(
mediainfo=mediainfo,
keyword=keyword
keyword=keyword,
sites=sites
)
if not torrents:
logger.warn(f'{keyword or mediainfo.title} 未搜索到资源')
@ -199,17 +202,22 @@ class SearchChain(ChainBase):
return contexts
def __search_all_sites(self, mediainfo: Optional[MediaInfo] = None,
keyword: str = None) -> Optional[List[TorrentInfo]]:
keyword: str = None,
sites: List[int] = None) -> Optional[List[TorrentInfo]]:
"""
多线程搜索多个站点
:param mediainfo: 识别的媒体信息
:param keyword: 搜索关键词如有按关键词搜索否则按媒体信息名称搜索
:param sites: 指定站点ID列表如有则只搜索指定站点否则搜索所有站点
:reutrn: 资源列表
"""
# 未开启的站点不搜索
indexer_sites = []
# 配置的索引站点
config_indexers = [str(sid) for sid in self.systemconfig.get(SystemConfigKey.IndexerSites) or []]
if sites:
config_indexers = [str(sid) for sid in sites]
else:
config_indexers = [str(sid) for sid in self.systemconfig.get(SystemConfigKey.IndexerSites) or []]
for indexer in self.siteshelper.get_indexers():
# 检查站点索引开关
if not config_indexers or str(indexer.get("id")) in config_indexers:

View File

@ -201,10 +201,16 @@ class SubscribeChain(ChainBase):
start_episode=subscribe.start_episode,
)
# 站点范围
if subscribe.sites:
sites = json.loads(subscribe.sites)
else:
sites = None
# 搜索
contexts = self.searchchain.process(mediainfo=mediainfo,
keyword=subscribe.keyword,
no_exists=no_exists)
no_exists=no_exists,
sites=sites)
if not contexts:
logger.warn(f'订阅 {subscribe.keyword or subscribe.name} 未搜索到资源')
# 未搜索到资源,但本地缺失可能有变化,更新订阅剩余集数
@ -357,6 +363,11 @@ class SubscribeChain(ChainBase):
torrent_meta = context.meta_info
torrent_mediainfo = context.media_info
torrent_info = context.torrent_info
# 不在订阅站点范围的不处理
if subscribe.sites:
sub_sites = json.loads(subscribe.sites)
if sub_sites and torrent_info.site not in sub_sites:
continue
# 如果是电视剧过滤掉已经下载的集数
if torrent_mediainfo.type == MediaType.TV:
if self.__check_subscribe_note(subscribe, torrent_meta.episode_list):

View File

@ -2,7 +2,6 @@ from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session
from app.db.models import Base
from app.schemas import MediaType
class Subscribe(Base):
@ -52,6 +51,8 @@ class Subscribe(Base):
last_update = Column(String)
# 订阅用户
username = Column(String)
# 订阅站点
sites = Column(String)
@staticmethod
def exists(db: Session, tmdbid: int, season: int = None):

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, List
from pydantic import BaseModel
@ -45,6 +45,8 @@ class Subscribe(BaseModel):
last_update: Optional[str] = None
# 订阅用户
username: Optional[str] = None
# 订阅站点
sites: Optional[List[int]] = None
class Config:
orm_mode = True