fix 订阅刷新站点范围扩大问题

This commit is contained in:
jxxghp 2024-05-24 12:31:11 +08:00
parent a2b0c9bd3a
commit 27b4f206a1
4 changed files with 32 additions and 2 deletions

View File

@ -16,6 +16,7 @@ from app.core.event import eventmanager, Event, EventManager
from app.core.meta import MetaBase from app.core.meta import MetaBase
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.db.models.subscribe import Subscribe from app.db.models.subscribe import Subscribe
from app.db.site_oper import SiteOper
from app.db.subscribe_oper import SubscribeOper from app.db.subscribe_oper import SubscribeOper
from app.db.subscribehistory_oper import SubscribeHistoryOper from app.db.subscribehistory_oper import SubscribeHistoryOper
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
@ -44,6 +45,7 @@ class SubscribeChain(ChainBase):
self.message = MessageHelper() self.message = MessageHelper()
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper()
self.torrenthelper = TorrentHelper() self.torrenthelper = TorrentHelper()
self.siteoper = SiteOper()
def add(self, title: str, year: str, def add(self, title: str, year: str,
mtype: MediaType = None, mtype: MediaType = None,
@ -525,6 +527,15 @@ class SubscribeChain(ChainBase):
meta.year = subscribe.year meta.year = subscribe.year
meta.begin_season = subscribe.season or None meta.begin_season = subscribe.season or None
meta.type = MediaType(subscribe.type) meta.type = MediaType(subscribe.type)
# 订阅的站点域名列表
domains = []
if subscribe.sites:
try:
siteids = json.loads(subscribe.sites)
if siteids:
domains = self.siteoper.get_domains_by_ids(siteids)
except JSONDecodeError:
pass
# 识别媒体信息 # 识别媒体信息
mediainfo: MediaInfo = self.recognize_media(meta=meta, mtype=meta.type, mediainfo: MediaInfo = self.recognize_media(meta=meta, mtype=meta.type,
tmdbid=subscribe.tmdbid, tmdbid=subscribe.tmdbid,
@ -593,7 +604,9 @@ class SubscribeChain(ChainBase):
# 遍历缓存种子 # 遍历缓存种子
_match_context = [] _match_context = []
for domain, contexts in torrents.items(): for domain, contexts in torrents.items():
logger.info(f'开始匹配站点:{domain},共缓存了 {len(contexts)} 个种子...') if domains and domain not in domains:
continue
logger.debug(f'开始匹配站点:{domain},共缓存了 {len(contexts)} 个种子...')
for context in contexts: for context in contexts:
# 检查是否匹配 # 检查是否匹配
torrent_meta = context.meta_info torrent_meta = context.meta_info

View File

@ -153,12 +153,15 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
# 所有站点索引 # 所有站点索引
indexers = self.siteshelper.get_indexers() indexers = self.siteshelper.get_indexers()
# 需要刷新的站点domain
domains = []
# 遍历站点缓存资源 # 遍历站点缓存资源
for indexer in indexers: for indexer in indexers:
# 未开启的站点不刷新 # 未开启的站点不刷新
if sites and indexer.get("id") not in sites: if sites and indexer.get("id") not in sites:
continue continue
domain = StringUtils.get_url_domain(indexer.get("domain")) domain = StringUtils.get_url_domain(indexer.get("domain"))
domains.append(domain)
if stype == "spider": if stype == "spider":
# 刷新首页种子 # 刷新首页种子
torrents: List[TorrentInfo] = self.browse(domain=domain) torrents: List[TorrentInfo] = self.browse(domain=domain)
@ -219,7 +222,9 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
else: else:
self.save_cache(torrents_cache, self._rss_file) self.save_cache(torrents_cache, self._rss_file)
# 返回 # 去除不在站点范围内的缓存种子
if sites and torrents_cache:
torrents_cache = {k: v for k, v in torrents_cache.items() if k in domains}
return torrents_cache return torrents_cache
def __renew_rss_url(self, domain: str, site: dict): def __renew_rss_url(self, domain: str, site: dict):

View File

@ -69,6 +69,12 @@ class Site(Base):
result = db.query(Site).order_by(Site.pri).all() result = db.query(Site).order_by(Site.pri).all()
return list(result) return list(result)
@staticmethod
@db_query
def get_domains_by_ids(db: Session, ids: list):
result = db.query(Site.domain).filter(Site.id.in_(ids)).all()
return [r[0] for r in result]
@staticmethod @staticmethod
@db_update @db_update
def reset(db: Session): def reset(db: Session):

View File

@ -63,6 +63,12 @@ class SiteOper(DbOper):
""" """
return Site.get_by_domain(self._db, domain) return Site.get_by_domain(self._db, domain)
def get_domains_by_ids(self, ids: List[int]) -> List[str]:
"""
按ID获取站点域名
"""
return Site.get_domains_by_ids(self._db, ids)
def exists(self, domain: str) -> bool: def exists(self, domain: str) -> bool:
""" """
判断站点是否存在 判断站点是否存在