diff --git a/app/chain/douban_sync.py b/app/chain/douban_sync.py index 50876665..d7ae1980 100644 --- a/app/chain/douban_sync.py +++ b/app/chain/douban_sync.py @@ -2,12 +2,12 @@ from pathlib import Path from typing import Optional from app.chain import ChainBase -from app.chain.common import CommonChain +from app.chain.download import DownloadChain from app.chain.search import SearchChain +from app.chain.subscribe import SubscribeChain from app.core.config import settings from app.core.meta_info import MetaInfo from app.core.context import MediaInfo -from app.db.subscribes import Subscribes from app.helper.rss import RssHelper from app.log import logger @@ -24,9 +24,9 @@ class DoubanSyncChain(ChainBase): def __init__(self): super().__init__() self.rsshelper = RssHelper() - self.common = CommonChain() + self.downloadchain = DownloadChain() self.searchchain = SearchChain() - self.subscribes = Subscribes() + self.subscribechain = SubscribeChain() def process(self): """ @@ -74,7 +74,7 @@ class DoubanSyncChain(ChainBase): # 加入缓存 caches.append(douban_id) # 查询缺失的媒体信息 - exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=mediainfo) + exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=mediainfo) if exist_flag: logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在') continue @@ -85,7 +85,7 @@ class DoubanSyncChain(ChainBase): logger.warn(f'{mediainfo.get_title_string()} 未搜索到资源') continue # 自动下载 - downloads, lefts = self.common.batch_download(contexts=contexts, need_tvs=no_exists) + downloads, lefts = self.downloadchain.batch_download(contexts=contexts, need_tvs=no_exists) if downloads and not lefts: # 全部下载完成 logger.info(f'{mediainfo.get_title_string()} 下载完成') @@ -93,14 +93,11 @@ class DoubanSyncChain(ChainBase): # 未完成下载 logger.info(f'{mediainfo.get_title_string()} 未下载未完整,添加订阅 ...') # 添加订阅 - state, msg = self.subscribes.add(mediainfo, - season=meta.begin_season) - if state: - # 订阅成功 - self.common.post_message( - title=f"{mediainfo.get_title_string()} 已添加订阅", - text="来自:豆瓣想看", - image=mediainfo.get_message_image()) + self.subscribechain.process(title=mediainfo.title, + mtype=mediainfo.type, + tmdbid=mediainfo.tmdb_id, + season=meta.begin_season, + username="豆瓣想看") logger.info(f"用户 {user_id} 豆瓣想看同步完成") # 保存缓存 diff --git a/app/chain/common.py b/app/chain/download.py similarity index 99% rename from app/chain/common.py rename to app/chain/download.py index 8dc6e44b..d82e9477 100644 --- a/app/chain/common.py +++ b/app/chain/download.py @@ -11,7 +11,7 @@ from app.utils.string import StringUtils from app.utils.types import MediaType -class CommonChain(ChainBase): +class DownloadChain(ChainBase): def __init__(self): super().__init__() diff --git a/app/chain/search.py b/app/chain/search.py index 68cf3aa1..3eb1262d 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -1,11 +1,10 @@ from typing import Optional, List from app.chain import ChainBase -from app.chain.common import CommonChain -from app.core.context import Context, MediaInfo, TorrentInfo from app.core.config import settings -from app.core.meta_info import MetaInfo +from app.core.context import Context, MediaInfo, TorrentInfo from app.core.meta import MetaBase +from app.core.meta_info import MetaInfo from app.helper.sites import SitesHelper from app.log import logger @@ -17,7 +16,6 @@ class SearchChain(ChainBase): def __init__(self): super().__init__() - self.common = CommonChain() self.siteshelper = SitesHelper() def process(self, meta: MetaBase, mediainfo: MediaInfo, diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 870bb034..f7c0b82b 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional from app.chain import ChainBase -from app.chain.common import CommonChain +from app.chain.download import DownloadChain from app.chain.search import SearchChain from app.core.meta_info import MetaInfo from app.core.context import TorrentInfo, Context, MediaInfo @@ -23,7 +23,7 @@ class SubscribeChain(ChainBase): def __init__(self): super().__init__() - self.common = CommonChain() + self.downloadchain = DownloadChain() self.searchchain = SearchChain() self.subscribes = Subscribes() self.siteshelper = SitesHelper() @@ -32,6 +32,7 @@ class SubscribeChain(ChainBase): mtype: MediaType = None, tmdbid: str = None, season: int = None, + userid: str = None, username: str = None, **kwargs) -> bool: """ @@ -60,11 +61,18 @@ class SubscribeChain(ChainBase): state, err_msg = self.subscribes.add(mediainfo, season=season, **kwargs) if state: logger.info(f'{mediainfo.get_title_string()} {err_msg}') + # 发回原用户 + self.post_message(title=f"{mediainfo.get_title_string()}{metainfo.get_season_string()} " + f"添加订阅失败!", + text=f"{err_msg}", + image=mediainfo.get_message_image(), + userid=userid) else: logger.error(f'{mediainfo.get_title_string()} 添加订阅成功') - self.common.post_message(title=f"{mediainfo.get_title_string()} 已添加订阅", - text=f"来自用户:{username}", - image=mediainfo.get_message_image()) + # 广而告之 + self.post_message(title=f"{mediainfo.get_title_string()}{metainfo.get_season_string()} 已添加订阅", + text=f"来自用户:{username or userid}", + image=mediainfo.get_message_image()) # 返回结果 return state @@ -96,7 +104,7 @@ class SubscribeChain(ChainBase): logger.warn(f'未识别到媒体信息,标题:{subscribe.name},tmdbid:{subscribe.tmdbid}') continue # 查询缺失的媒体信息 - exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=mediainfo) + exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=mediainfo) if exist_flag: logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在,完成订阅') self.subscribes.delete(subscribe.id) @@ -110,7 +118,7 @@ class SubscribeChain(ChainBase): logger.warn(f'{subscribe.keyword or subscribe.name} 未搜索到资源') continue # 自动下载 - downloads, lefts = self.common.batch_download(contexts=contexts, need_tvs=no_exists) + downloads, lefts = self.downloadchain.batch_download(contexts=contexts, need_tvs=no_exists) if downloads and not lefts: # 全部下载完成 logger.info(f'{mediainfo.get_title_string()} 下载完成,完成订阅') @@ -180,7 +188,7 @@ class SubscribeChain(ChainBase): logger.warn(f'未识别到媒体信息,标题:{subscribe.name},tmdbid:{subscribe.tmdbid}') continue # 查询缺失的媒体信息 - exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=mediainfo) + exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=mediainfo) if exist_flag: logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在,完成订阅') self.subscribes.delete(subscribe.id) @@ -206,7 +214,7 @@ class SubscribeChain(ChainBase): logger.info(f'{mediainfo.get_title_string()} 匹配完成,共匹配到{len(_match_context)}个资源') if _match_context: # 批量择优下载 - downloads, lefts = self.common.batch_download(contexts=_match_context, need_tvs=no_exists) + downloads, lefts = self.downloadchain.batch_download(contexts=_match_context, need_tvs=no_exists) if downloads and not lefts: # 全部下载完成 logger.info(f'{mediainfo.get_title_string()} 下载完成,完成订阅') diff --git a/app/chain/user_message.py b/app/chain/user_message.py index 944594a6..c7f7140c 100644 --- a/app/chain/user_message.py +++ b/app/chain/user_message.py @@ -1,11 +1,11 @@ from typing import Any -from app.chain.common import * +from app.chain.download import * from app.chain.search import SearchChain +from app.chain.subscribe import SubscribeChain from app.core.context import MediaInfo, TorrentInfo from app.core.meta_info import MetaInfo from app.core.event_manager import EventManager -from app.db.subscribes import Subscribes from app.log import logger from app.utils.types import EventType @@ -27,8 +27,8 @@ class UserMessageChain(ChainBase): def __init__(self): super().__init__() - self.common = CommonChain() - self.subscribes = Subscribes() + self.downloadchain = DownloadChain() + self.subscribechain = SubscribeChain() self.searchchain = SearchChain() self.torrent = TorrentHelper() self.eventmanager = EventManager() @@ -43,6 +43,7 @@ class UserMessageChain(ChainBase): return # 用户ID userid = info.get('userid') + username = info.get('username') if not userid: logger.debug(f'未识别到用户ID:{body}{form}{args}') return @@ -60,7 +61,7 @@ class UserMessageChain(ChainBase): "cmd": text } ) - self.common.post_message(title=f"正在运行,请稍候 ...", userid=userid) + self.post_message(title=f"正在运行,请稍候 ...", userid=userid) elif text.isdigit(): # 缓存 @@ -70,7 +71,7 @@ class UserMessageChain(ChainBase): or not cache_data.get('items') \ or len(cache_data.get('items')) < int(text): # 发送消息 - self.common.post_message(title="输入有误!", userid=userid) + self.post_message(title="输入有误!", userid=userid) return # 缓存类型 cache_type: str = cache_data.get('type') @@ -84,17 +85,17 @@ class UserMessageChain(ChainBase): exists: dict = self.media_exists(mediainfo=mediainfo) if exists: # 已存在 - self.common.post_message( + self.post_message( title=f"{mediainfo.type.value} {mediainfo.get_title_string()} 媒体库中已存在", userid=userid) return logger.info(f"{mediainfo.get_title_string()} 媒体库中不存在,开始搜索 ...") - self.common.post_message( + self.post_message( title=f"开始搜索 {mediainfo.type.value} {mediainfo.get_title_string()} ...", userid=userid) # 搜索种子 contexts = self.searchchain.process(meta=self._current_meta, mediainfo=mediainfo) if not contexts: # 没有数据 - self.common.post_message(title=f"{mediainfo.title} 未搜索到资源!", userid=userid) + self.post_message(title=f"{mediainfo.title} 未搜索到资源!", userid=userid) return # 更新缓存 self._user_cache[userid] = { @@ -109,36 +110,25 @@ class UserMessageChain(ChainBase): elif cache_type == "Subscribe": # 订阅媒体 mediainfo: MediaInfo = cache_list[int(text) - 1] - # 补充识别媒体信息 - mediainfo: MediaInfo = self.recognize_media(meta=self._current_meta, tmdbid=mediainfo.tmdb_id) - if not mediainfo: - logger.warn(f'未识别到媒体信息,tmdbid:{mediainfo.tmdb_id}') - return - self._current_media = mediainfo - state, msg = self.subscribes.add(mediainfo, - season=self._current_meta.begin_season) - if state: - # 订阅成功 - self.common.post_message( - title=f"{mediainfo.get_title_string()} 已添加订阅", - image=mediainfo.get_message_image(), - userid=userid) - else: - # 订阅失败 - self.common.post_message(title=f"{mediainfo.title} 添加订阅失败:{msg}", userid=userid) + self.subscribechain.process(title=mediainfo.title, + mtype=mediainfo.type, + tmdbid=mediainfo.tmdb_id, + season=self._current_meta.begin_season, + userid=userid, + username=username) elif cache_type == "Torrent": if int(text) == 0: # 自动选择下载 # 查询缺失的媒体信息 - exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=self._current_media) + exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=self._current_media) if exist_flag: - self.common.post_message(title=f"{self._current_media.get_title_string()} 媒体库中已存在", - userid=userid) + self.post_message(title=f"{self._current_media.get_title_string()} 媒体库中已存在", + userid=userid) return # 批量下载 - downloads, lefts = self.common.batch_download(contexts=cache_list, - need_tvs=no_exists, - userid=userid) + downloads, lefts = self.downloadchain.batch_download(contexts=cache_list, + need_tvs=no_exists, + userid=userid) if downloads and not lefts: # 全部下载完成 logger.info(f'{self._current_media.get_title_string()} 下载完成') @@ -146,14 +136,12 @@ class UserMessageChain(ChainBase): # 未完成下载 logger.info(f'{self._current_media.get_title_string()} 未下载未完整,添加订阅 ...') # 添加订阅 - state, msg = self.subscribes.add(self._current_media, - season=self._current_meta.begin_season) - if state: - # 订阅成功 - self.common.post_message( - title=f"{self._current_media.get_title_string()} 已添加订阅", - text=f"来自用户:{userid}", - image=self._current_media.get_message_image()) + self.subscribechain.process(title=self._current_media.title, + mtype=self._current_media.type, + tmdbid=self._current_media.tmdb_id, + season=self._current_meta.begin_season, + userid=userid, + username=username) else: # 下载种子 context: Context = cache_list[int(text) - 1] @@ -181,24 +169,24 @@ class UserMessageChain(ChainBase): # 发送消息 if not state: # 下载失败 - self.common.post_message(title=f"{torrent.title} 添加下载失败!", - text=f"错误信息:{msg}", - userid=userid) + self.post_message(title=f"{torrent.title} 添加下载失败!", + text=f"错误信息:{msg}", + userid=userid) return # 下载成功,发送通知 - self.common.post_download_message(meta=meta, mediainfo=self._current_media, torrent=torrent) + self.downloadchain.post_download_message(meta=meta, mediainfo=self._current_media, torrent=torrent) elif text.lower() == "p": # 上一页 cache_data: dict = self._user_cache.get(userid) if not cache_data: # 没有缓存 - self.common.post_message(title="输入有误!", userid=userid) + self.post_message(title="输入有误!", userid=userid) return if self._current_page == 0: # 第一页 - self.common.post_message(title="已经是第一页了!", userid=userid) + self.post_message(title="已经是第一页了!", userid=userid) return cache_type: str = cache_data.get('type') cache_list: list = cache_data.get('items') @@ -222,7 +210,7 @@ class UserMessageChain(ChainBase): cache_data: dict = self._user_cache.get(userid) if not cache_data: # 没有缓存 - self.common.post_message(title="输入有误!", userid=userid) + self.post_message(title="输入有误!", userid=userid) return cache_type: str = cache_data.get('type') cache_list: list = cache_data.get('items') @@ -232,7 +220,7 @@ class UserMessageChain(ChainBase): cache_list = cache_list[self._current_page * self._page_size:(self._current_page + 1) * self._page_size] if not cache_list: # 没有数据 - self.common.post_message(title="已经是最后一页了!", userid=userid) + self.post_message(title="已经是最后一页了!", userid=userid) return else: if cache_type == "Torrent": @@ -257,7 +245,7 @@ class UserMessageChain(ChainBase): # 识别 meta = MetaInfo(title) if not meta.get_name(): - self.common.post_message(title="无法识别输入内容!", userid=userid) + self.post_message(title="无法识别输入内容!", userid=userid) return # 合并信息 if mtype: @@ -275,7 +263,7 @@ class UserMessageChain(ChainBase): logger.info(f"开始搜索:{meta.get_name()}") medias: Optional[List[MediaInfo]] = self.search_medias(meta=meta) if not medias: - self.common.post_message(title=f"{meta.get_name()} 没有找到对应的媒体信息!", userid=userid) + self.post_message(title=f"{meta.get_name()} 没有找到对应的媒体信息!", userid=userid) return self._user_cache[userid] = { 'type': action, diff --git a/tests/test_recognize.py b/tests/test_recognize.py index 9dc1290b..c5220a43 100644 --- a/tests/test_recognize.py +++ b/tests/test_recognize.py @@ -2,7 +2,7 @@ from unittest import TestCase -from app.chain.common import CommonChain +from app.chain.download import DownloadChain from app.chain.identify import IdentifyChain @@ -16,5 +16,5 @@ class RecognizeTest(TestCase): def test_recognize(self): result = IdentifyChain().process(title="我和我的祖国 2019") self.assertEqual(str(result.media_info.tmdb_id), '612845') - exists = CommonChain().get_no_exists_info(result.media_info) + exists = DownloadChain().get_no_exists_info(result.media_info) self.assertTrue(exists[0])