fix chain depends

This commit is contained in:
jxxghp 2023-06-09 23:44:55 +08:00
parent b9012c5666
commit 51da075a65
6 changed files with 71 additions and 80 deletions

View File

@ -2,12 +2,12 @@ from pathlib import Path
from typing import Optional
from app.chain import ChainBase
from app.chain.common import CommonChain
from app.chain.download import DownloadChain
from app.chain.search import SearchChain
from app.chain.subscribe import SubscribeChain
from app.core.config import settings
from app.core.meta_info import MetaInfo
from app.core.context import MediaInfo
from app.db.subscribes import Subscribes
from app.helper.rss import RssHelper
from app.log import logger
@ -24,9 +24,9 @@ class DoubanSyncChain(ChainBase):
def __init__(self):
super().__init__()
self.rsshelper = RssHelper()
self.common = CommonChain()
self.downloadchain = DownloadChain()
self.searchchain = SearchChain()
self.subscribes = Subscribes()
self.subscribechain = SubscribeChain()
def process(self):
"""
@ -74,7 +74,7 @@ class DoubanSyncChain(ChainBase):
# 加入缓存
caches.append(douban_id)
# 查询缺失的媒体信息
exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=mediainfo)
exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=mediainfo)
if exist_flag:
logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在')
continue
@ -85,7 +85,7 @@ class DoubanSyncChain(ChainBase):
logger.warn(f'{mediainfo.get_title_string()} 未搜索到资源')
continue
# 自动下载
downloads, lefts = self.common.batch_download(contexts=contexts, need_tvs=no_exists)
downloads, lefts = self.downloadchain.batch_download(contexts=contexts, need_tvs=no_exists)
if downloads and not lefts:
# 全部下载完成
logger.info(f'{mediainfo.get_title_string()} 下载完成')
@ -93,14 +93,11 @@ class DoubanSyncChain(ChainBase):
# 未完成下载
logger.info(f'{mediainfo.get_title_string()} 未下载未完整,添加订阅 ...')
# 添加订阅
state, msg = self.subscribes.add(mediainfo,
season=meta.begin_season)
if state:
# 订阅成功
self.common.post_message(
title=f"{mediainfo.get_title_string()} 已添加订阅",
text="来自:豆瓣想看",
image=mediainfo.get_message_image())
self.subscribechain.process(title=mediainfo.title,
mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
season=meta.begin_season,
username="豆瓣想看")
logger.info(f"用户 {user_id} 豆瓣想看同步完成")
# 保存缓存

View File

@ -11,7 +11,7 @@ from app.utils.string import StringUtils
from app.utils.types import MediaType
class CommonChain(ChainBase):
class DownloadChain(ChainBase):
def __init__(self):
super().__init__()

View File

@ -1,11 +1,10 @@
from typing import Optional, List
from app.chain import ChainBase
from app.chain.common import CommonChain
from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.config import settings
from app.core.meta_info import MetaInfo
from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.meta import MetaBase
from app.core.meta_info import MetaInfo
from app.helper.sites import SitesHelper
from app.log import logger
@ -17,7 +16,6 @@ class SearchChain(ChainBase):
def __init__(self):
super().__init__()
self.common = CommonChain()
self.siteshelper = SitesHelper()
def process(self, meta: MetaBase, mediainfo: MediaInfo,

View File

@ -1,7 +1,7 @@
from typing import Dict, List, Optional
from app.chain import ChainBase
from app.chain.common import CommonChain
from app.chain.download import DownloadChain
from app.chain.search import SearchChain
from app.core.meta_info import MetaInfo
from app.core.context import TorrentInfo, Context, MediaInfo
@ -23,7 +23,7 @@ class SubscribeChain(ChainBase):
def __init__(self):
super().__init__()
self.common = CommonChain()
self.downloadchain = DownloadChain()
self.searchchain = SearchChain()
self.subscribes = Subscribes()
self.siteshelper = SitesHelper()
@ -32,6 +32,7 @@ class SubscribeChain(ChainBase):
mtype: MediaType = None,
tmdbid: str = None,
season: int = None,
userid: str = None,
username: str = None,
**kwargs) -> bool:
"""
@ -60,10 +61,17 @@ class SubscribeChain(ChainBase):
state, err_msg = self.subscribes.add(mediainfo, season=season, **kwargs)
if state:
logger.info(f'{mediainfo.get_title_string()} {err_msg}')
# 发回原用户
self.post_message(title=f"{mediainfo.get_title_string()}{metainfo.get_season_string()} "
f"添加订阅失败!",
text=f"{err_msg}",
image=mediainfo.get_message_image(),
userid=userid)
else:
logger.error(f'{mediainfo.get_title_string()} 添加订阅成功')
self.common.post_message(title=f"{mediainfo.get_title_string()} 已添加订阅",
text=f"来自用户:{username}",
# 广而告之
self.post_message(title=f"{mediainfo.get_title_string()}{metainfo.get_season_string()} 已添加订阅",
text=f"来自用户:{username or userid}",
image=mediainfo.get_message_image())
# 返回结果
return state
@ -96,7 +104,7 @@ class SubscribeChain(ChainBase):
logger.warn(f'未识别到媒体信息,标题:{subscribe.name}tmdbid{subscribe.tmdbid}')
continue
# 查询缺失的媒体信息
exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=mediainfo)
exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=mediainfo)
if exist_flag:
logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在,完成订阅')
self.subscribes.delete(subscribe.id)
@ -110,7 +118,7 @@ class SubscribeChain(ChainBase):
logger.warn(f'{subscribe.keyword or subscribe.name} 未搜索到资源')
continue
# 自动下载
downloads, lefts = self.common.batch_download(contexts=contexts, need_tvs=no_exists)
downloads, lefts = self.downloadchain.batch_download(contexts=contexts, need_tvs=no_exists)
if downloads and not lefts:
# 全部下载完成
logger.info(f'{mediainfo.get_title_string()} 下载完成,完成订阅')
@ -180,7 +188,7 @@ class SubscribeChain(ChainBase):
logger.warn(f'未识别到媒体信息,标题:{subscribe.name}tmdbid{subscribe.tmdbid}')
continue
# 查询缺失的媒体信息
exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=mediainfo)
exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=mediainfo)
if exist_flag:
logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在,完成订阅')
self.subscribes.delete(subscribe.id)
@ -206,7 +214,7 @@ class SubscribeChain(ChainBase):
logger.info(f'{mediainfo.get_title_string()} 匹配完成,共匹配到{len(_match_context)}个资源')
if _match_context:
# 批量择优下载
downloads, lefts = self.common.batch_download(contexts=_match_context, need_tvs=no_exists)
downloads, lefts = self.downloadchain.batch_download(contexts=_match_context, need_tvs=no_exists)
if downloads and not lefts:
# 全部下载完成
logger.info(f'{mediainfo.get_title_string()} 下载完成,完成订阅')

View File

@ -1,11 +1,11 @@
from typing import Any
from app.chain.common import *
from app.chain.download import *
from app.chain.search import SearchChain
from app.chain.subscribe import SubscribeChain
from app.core.context import MediaInfo, TorrentInfo
from app.core.meta_info import MetaInfo
from app.core.event_manager import EventManager
from app.db.subscribes import Subscribes
from app.log import logger
from app.utils.types import EventType
@ -27,8 +27,8 @@ class UserMessageChain(ChainBase):
def __init__(self):
super().__init__()
self.common = CommonChain()
self.subscribes = Subscribes()
self.downloadchain = DownloadChain()
self.subscribechain = SubscribeChain()
self.searchchain = SearchChain()
self.torrent = TorrentHelper()
self.eventmanager = EventManager()
@ -43,6 +43,7 @@ class UserMessageChain(ChainBase):
return
# 用户ID
userid = info.get('userid')
username = info.get('username')
if not userid:
logger.debug(f'未识别到用户ID{body}{form}{args}')
return
@ -60,7 +61,7 @@ class UserMessageChain(ChainBase):
"cmd": text
}
)
self.common.post_message(title=f"正在运行,请稍候 ...", userid=userid)
self.post_message(title=f"正在运行,请稍候 ...", userid=userid)
elif text.isdigit():
# 缓存
@ -70,7 +71,7 @@ class UserMessageChain(ChainBase):
or not cache_data.get('items') \
or len(cache_data.get('items')) < int(text):
# 发送消息
self.common.post_message(title="输入有误!", userid=userid)
self.post_message(title="输入有误!", userid=userid)
return
# 缓存类型
cache_type: str = cache_data.get('type')
@ -84,17 +85,17 @@ class UserMessageChain(ChainBase):
exists: dict = self.media_exists(mediainfo=mediainfo)
if exists:
# 已存在
self.common.post_message(
self.post_message(
title=f"{mediainfo.type.value} {mediainfo.get_title_string()} 媒体库中已存在", userid=userid)
return
logger.info(f"{mediainfo.get_title_string()} 媒体库中不存在,开始搜索 ...")
self.common.post_message(
self.post_message(
title=f"开始搜索 {mediainfo.type.value} {mediainfo.get_title_string()} ...", userid=userid)
# 搜索种子
contexts = self.searchchain.process(meta=self._current_meta, mediainfo=mediainfo)
if not contexts:
# 没有数据
self.common.post_message(title=f"{mediainfo.title} 未搜索到资源!", userid=userid)
self.post_message(title=f"{mediainfo.title} 未搜索到资源!", userid=userid)
return
# 更新缓存
self._user_cache[userid] = {
@ -109,34 +110,23 @@ class UserMessageChain(ChainBase):
elif cache_type == "Subscribe":
# 订阅媒体
mediainfo: MediaInfo = cache_list[int(text) - 1]
# 补充识别媒体信息
mediainfo: MediaInfo = self.recognize_media(meta=self._current_meta, tmdbid=mediainfo.tmdb_id)
if not mediainfo:
logger.warn(f'未识别到媒体信息tmdbid{mediainfo.tmdb_id}')
return
self._current_media = mediainfo
state, msg = self.subscribes.add(mediainfo,
season=self._current_meta.begin_season)
if state:
# 订阅成功
self.common.post_message(
title=f"{mediainfo.get_title_string()} 已添加订阅",
image=mediainfo.get_message_image(),
userid=userid)
else:
# 订阅失败
self.common.post_message(title=f"{mediainfo.title} 添加订阅失败:{msg}", userid=userid)
self.subscribechain.process(title=mediainfo.title,
mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
season=self._current_meta.begin_season,
userid=userid,
username=username)
elif cache_type == "Torrent":
if int(text) == 0:
# 自动选择下载
# 查询缺失的媒体信息
exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=self._current_media)
exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=self._current_media)
if exist_flag:
self.common.post_message(title=f"{self._current_media.get_title_string()} 媒体库中已存在",
self.post_message(title=f"{self._current_media.get_title_string()} 媒体库中已存在",
userid=userid)
return
# 批量下载
downloads, lefts = self.common.batch_download(contexts=cache_list,
downloads, lefts = self.downloadchain.batch_download(contexts=cache_list,
need_tvs=no_exists,
userid=userid)
if downloads and not lefts:
@ -146,14 +136,12 @@ class UserMessageChain(ChainBase):
# 未完成下载
logger.info(f'{self._current_media.get_title_string()} 未下载未完整,添加订阅 ...')
# 添加订阅
state, msg = self.subscribes.add(self._current_media,
season=self._current_meta.begin_season)
if state:
# 订阅成功
self.common.post_message(
title=f"{self._current_media.get_title_string()} 已添加订阅",
text=f"来自用户:{userid}",
image=self._current_media.get_message_image())
self.subscribechain.process(title=self._current_media.title,
mtype=self._current_media.type,
tmdbid=self._current_media.tmdb_id,
season=self._current_meta.begin_season,
userid=userid,
username=username)
else:
# 下载种子
context: Context = cache_list[int(text) - 1]
@ -181,24 +169,24 @@ class UserMessageChain(ChainBase):
# 发送消息
if not state:
# 下载失败
self.common.post_message(title=f"{torrent.title} 添加下载失败!",
self.post_message(title=f"{torrent.title} 添加下载失败!",
text=f"错误信息:{msg}",
userid=userid)
return
# 下载成功,发送通知
self.common.post_download_message(meta=meta, mediainfo=self._current_media, torrent=torrent)
self.downloadchain.post_download_message(meta=meta, mediainfo=self._current_media, torrent=torrent)
elif text.lower() == "p":
# 上一页
cache_data: dict = self._user_cache.get(userid)
if not cache_data:
# 没有缓存
self.common.post_message(title="输入有误!", userid=userid)
self.post_message(title="输入有误!", userid=userid)
return
if self._current_page == 0:
# 第一页
self.common.post_message(title="已经是第一页了!", userid=userid)
self.post_message(title="已经是第一页了!", userid=userid)
return
cache_type: str = cache_data.get('type')
cache_list: list = cache_data.get('items')
@ -222,7 +210,7 @@ class UserMessageChain(ChainBase):
cache_data: dict = self._user_cache.get(userid)
if not cache_data:
# 没有缓存
self.common.post_message(title="输入有误!", userid=userid)
self.post_message(title="输入有误!", userid=userid)
return
cache_type: str = cache_data.get('type')
cache_list: list = cache_data.get('items')
@ -232,7 +220,7 @@ class UserMessageChain(ChainBase):
cache_list = cache_list[self._current_page * self._page_size:(self._current_page + 1) * self._page_size]
if not cache_list:
# 没有数据
self.common.post_message(title="已经是最后一页了!", userid=userid)
self.post_message(title="已经是最后一页了!", userid=userid)
return
else:
if cache_type == "Torrent":
@ -257,7 +245,7 @@ class UserMessageChain(ChainBase):
# 识别
meta = MetaInfo(title)
if not meta.get_name():
self.common.post_message(title="无法识别输入内容!", userid=userid)
self.post_message(title="无法识别输入内容!", userid=userid)
return
# 合并信息
if mtype:
@ -275,7 +263,7 @@ class UserMessageChain(ChainBase):
logger.info(f"开始搜索:{meta.get_name()}")
medias: Optional[List[MediaInfo]] = self.search_medias(meta=meta)
if not medias:
self.common.post_message(title=f"{meta.get_name()} 没有找到对应的媒体信息!", userid=userid)
self.post_message(title=f"{meta.get_name()} 没有找到对应的媒体信息!", userid=userid)
return
self._user_cache[userid] = {
'type': action,

View File

@ -2,7 +2,7 @@
from unittest import TestCase
from app.chain.common import CommonChain
from app.chain.download import DownloadChain
from app.chain.identify import IdentifyChain
@ -16,5 +16,5 @@ class RecognizeTest(TestCase):
def test_recognize(self):
result = IdentifyChain().process(title="我和我的祖国 2019")
self.assertEqual(str(result.media_info.tmdb_id), '612845')
exists = CommonChain().get_no_exists_info(result.media_info)
exists = DownloadChain().get_no_exists_info(result.media_info)
self.assertTrue(exists[0])