fix chain depends

This commit is contained in:
jxxghp 2023-06-09 23:44:55 +08:00
parent b9012c5666
commit 51da075a65
6 changed files with 71 additions and 80 deletions

View File

@ -2,12 +2,12 @@ from pathlib import Path
from typing import Optional from typing import Optional
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.common import CommonChain from app.chain.download import DownloadChain
from app.chain.search import SearchChain from app.chain.search import SearchChain
from app.chain.subscribe import SubscribeChain
from app.core.config import settings from app.core.config import settings
from app.core.meta_info import MetaInfo from app.core.meta_info import MetaInfo
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.db.subscribes import Subscribes
from app.helper.rss import RssHelper from app.helper.rss import RssHelper
from app.log import logger from app.log import logger
@ -24,9 +24,9 @@ class DoubanSyncChain(ChainBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.rsshelper = RssHelper() self.rsshelper = RssHelper()
self.common = CommonChain() self.downloadchain = DownloadChain()
self.searchchain = SearchChain() self.searchchain = SearchChain()
self.subscribes = Subscribes() self.subscribechain = SubscribeChain()
def process(self): def process(self):
""" """
@ -74,7 +74,7 @@ class DoubanSyncChain(ChainBase):
# 加入缓存 # 加入缓存
caches.append(douban_id) caches.append(douban_id)
# 查询缺失的媒体信息 # 查询缺失的媒体信息
exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=mediainfo) exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=mediainfo)
if exist_flag: if exist_flag:
logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在') logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在')
continue continue
@ -85,7 +85,7 @@ class DoubanSyncChain(ChainBase):
logger.warn(f'{mediainfo.get_title_string()} 未搜索到资源') logger.warn(f'{mediainfo.get_title_string()} 未搜索到资源')
continue continue
# 自动下载 # 自动下载
downloads, lefts = self.common.batch_download(contexts=contexts, need_tvs=no_exists) downloads, lefts = self.downloadchain.batch_download(contexts=contexts, need_tvs=no_exists)
if downloads and not lefts: if downloads and not lefts:
# 全部下载完成 # 全部下载完成
logger.info(f'{mediainfo.get_title_string()} 下载完成') logger.info(f'{mediainfo.get_title_string()} 下载完成')
@ -93,14 +93,11 @@ class DoubanSyncChain(ChainBase):
# 未完成下载 # 未完成下载
logger.info(f'{mediainfo.get_title_string()} 未下载未完整,添加订阅 ...') logger.info(f'{mediainfo.get_title_string()} 未下载未完整,添加订阅 ...')
# 添加订阅 # 添加订阅
state, msg = self.subscribes.add(mediainfo, self.subscribechain.process(title=mediainfo.title,
season=meta.begin_season) mtype=mediainfo.type,
if state: tmdbid=mediainfo.tmdb_id,
# 订阅成功 season=meta.begin_season,
self.common.post_message( username="豆瓣想看")
title=f"{mediainfo.get_title_string()} 已添加订阅",
text="来自:豆瓣想看",
image=mediainfo.get_message_image())
logger.info(f"用户 {user_id} 豆瓣想看同步完成") logger.info(f"用户 {user_id} 豆瓣想看同步完成")
# 保存缓存 # 保存缓存

View File

@ -11,7 +11,7 @@ from app.utils.string import StringUtils
from app.utils.types import MediaType from app.utils.types import MediaType
class CommonChain(ChainBase): class DownloadChain(ChainBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -1,11 +1,10 @@
from typing import Optional, List from typing import Optional, List
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.common import CommonChain
from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.config import settings from app.core.config import settings
from app.core.meta_info import MetaInfo from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.meta import MetaBase from app.core.meta import MetaBase
from app.core.meta_info import MetaInfo
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper
from app.log import logger from app.log import logger
@ -17,7 +16,6 @@ class SearchChain(ChainBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.common = CommonChain()
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
def process(self, meta: MetaBase, mediainfo: MediaInfo, def process(self, meta: MetaBase, mediainfo: MediaInfo,

View File

@ -1,7 +1,7 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.common import CommonChain from app.chain.download import DownloadChain
from app.chain.search import SearchChain from app.chain.search import SearchChain
from app.core.meta_info import MetaInfo from app.core.meta_info import MetaInfo
from app.core.context import TorrentInfo, Context, MediaInfo from app.core.context import TorrentInfo, Context, MediaInfo
@ -23,7 +23,7 @@ class SubscribeChain(ChainBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.common = CommonChain() self.downloadchain = DownloadChain()
self.searchchain = SearchChain() self.searchchain = SearchChain()
self.subscribes = Subscribes() self.subscribes = Subscribes()
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
@ -32,6 +32,7 @@ class SubscribeChain(ChainBase):
mtype: MediaType = None, mtype: MediaType = None,
tmdbid: str = None, tmdbid: str = None,
season: int = None, season: int = None,
userid: str = None,
username: str = None, username: str = None,
**kwargs) -> bool: **kwargs) -> bool:
""" """
@ -60,10 +61,17 @@ class SubscribeChain(ChainBase):
state, err_msg = self.subscribes.add(mediainfo, season=season, **kwargs) state, err_msg = self.subscribes.add(mediainfo, season=season, **kwargs)
if state: if state:
logger.info(f'{mediainfo.get_title_string()} {err_msg}') logger.info(f'{mediainfo.get_title_string()} {err_msg}')
# 发回原用户
self.post_message(title=f"{mediainfo.get_title_string()}{metainfo.get_season_string()} "
f"添加订阅失败!",
text=f"{err_msg}",
image=mediainfo.get_message_image(),
userid=userid)
else: else:
logger.error(f'{mediainfo.get_title_string()} 添加订阅成功') logger.error(f'{mediainfo.get_title_string()} 添加订阅成功')
self.common.post_message(title=f"{mediainfo.get_title_string()} 已添加订阅", # 广而告之
text=f"来自用户:{username}", self.post_message(title=f"{mediainfo.get_title_string()}{metainfo.get_season_string()} 已添加订阅",
text=f"来自用户:{username or userid}",
image=mediainfo.get_message_image()) image=mediainfo.get_message_image())
# 返回结果 # 返回结果
return state return state
@ -96,7 +104,7 @@ class SubscribeChain(ChainBase):
logger.warn(f'未识别到媒体信息,标题:{subscribe.name}tmdbid{subscribe.tmdbid}') logger.warn(f'未识别到媒体信息,标题:{subscribe.name}tmdbid{subscribe.tmdbid}')
continue continue
# 查询缺失的媒体信息 # 查询缺失的媒体信息
exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=mediainfo) exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=mediainfo)
if exist_flag: if exist_flag:
logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在,完成订阅') logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在,完成订阅')
self.subscribes.delete(subscribe.id) self.subscribes.delete(subscribe.id)
@ -110,7 +118,7 @@ class SubscribeChain(ChainBase):
logger.warn(f'{subscribe.keyword or subscribe.name} 未搜索到资源') logger.warn(f'{subscribe.keyword or subscribe.name} 未搜索到资源')
continue continue
# 自动下载 # 自动下载
downloads, lefts = self.common.batch_download(contexts=contexts, need_tvs=no_exists) downloads, lefts = self.downloadchain.batch_download(contexts=contexts, need_tvs=no_exists)
if downloads and not lefts: if downloads and not lefts:
# 全部下载完成 # 全部下载完成
logger.info(f'{mediainfo.get_title_string()} 下载完成,完成订阅') logger.info(f'{mediainfo.get_title_string()} 下载完成,完成订阅')
@ -180,7 +188,7 @@ class SubscribeChain(ChainBase):
logger.warn(f'未识别到媒体信息,标题:{subscribe.name}tmdbid{subscribe.tmdbid}') logger.warn(f'未识别到媒体信息,标题:{subscribe.name}tmdbid{subscribe.tmdbid}')
continue continue
# 查询缺失的媒体信息 # 查询缺失的媒体信息
exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=mediainfo) exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=mediainfo)
if exist_flag: if exist_flag:
logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在,完成订阅') logger.info(f'{mediainfo.get_title_string()} 媒体库中已存在,完成订阅')
self.subscribes.delete(subscribe.id) self.subscribes.delete(subscribe.id)
@ -206,7 +214,7 @@ class SubscribeChain(ChainBase):
logger.info(f'{mediainfo.get_title_string()} 匹配完成,共匹配到{len(_match_context)}个资源') logger.info(f'{mediainfo.get_title_string()} 匹配完成,共匹配到{len(_match_context)}个资源')
if _match_context: if _match_context:
# 批量择优下载 # 批量择优下载
downloads, lefts = self.common.batch_download(contexts=_match_context, need_tvs=no_exists) downloads, lefts = self.downloadchain.batch_download(contexts=_match_context, need_tvs=no_exists)
if downloads and not lefts: if downloads and not lefts:
# 全部下载完成 # 全部下载完成
logger.info(f'{mediainfo.get_title_string()} 下载完成,完成订阅') logger.info(f'{mediainfo.get_title_string()} 下载完成,完成订阅')

View File

@ -1,11 +1,11 @@
from typing import Any from typing import Any
from app.chain.common import * from app.chain.download import *
from app.chain.search import SearchChain from app.chain.search import SearchChain
from app.chain.subscribe import SubscribeChain
from app.core.context import MediaInfo, TorrentInfo from app.core.context import MediaInfo, TorrentInfo
from app.core.meta_info import MetaInfo from app.core.meta_info import MetaInfo
from app.core.event_manager import EventManager from app.core.event_manager import EventManager
from app.db.subscribes import Subscribes
from app.log import logger from app.log import logger
from app.utils.types import EventType from app.utils.types import EventType
@ -27,8 +27,8 @@ class UserMessageChain(ChainBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.common = CommonChain() self.downloadchain = DownloadChain()
self.subscribes = Subscribes() self.subscribechain = SubscribeChain()
self.searchchain = SearchChain() self.searchchain = SearchChain()
self.torrent = TorrentHelper() self.torrent = TorrentHelper()
self.eventmanager = EventManager() self.eventmanager = EventManager()
@ -43,6 +43,7 @@ class UserMessageChain(ChainBase):
return return
# 用户ID # 用户ID
userid = info.get('userid') userid = info.get('userid')
username = info.get('username')
if not userid: if not userid:
logger.debug(f'未识别到用户ID{body}{form}{args}') logger.debug(f'未识别到用户ID{body}{form}{args}')
return return
@ -60,7 +61,7 @@ class UserMessageChain(ChainBase):
"cmd": text "cmd": text
} }
) )
self.common.post_message(title=f"正在运行,请稍候 ...", userid=userid) self.post_message(title=f"正在运行,请稍候 ...", userid=userid)
elif text.isdigit(): elif text.isdigit():
# 缓存 # 缓存
@ -70,7 +71,7 @@ class UserMessageChain(ChainBase):
or not cache_data.get('items') \ or not cache_data.get('items') \
or len(cache_data.get('items')) < int(text): or len(cache_data.get('items')) < int(text):
# 发送消息 # 发送消息
self.common.post_message(title="输入有误!", userid=userid) self.post_message(title="输入有误!", userid=userid)
return return
# 缓存类型 # 缓存类型
cache_type: str = cache_data.get('type') cache_type: str = cache_data.get('type')
@ -84,17 +85,17 @@ class UserMessageChain(ChainBase):
exists: dict = self.media_exists(mediainfo=mediainfo) exists: dict = self.media_exists(mediainfo=mediainfo)
if exists: if exists:
# 已存在 # 已存在
self.common.post_message( self.post_message(
title=f"{mediainfo.type.value} {mediainfo.get_title_string()} 媒体库中已存在", userid=userid) title=f"{mediainfo.type.value} {mediainfo.get_title_string()} 媒体库中已存在", userid=userid)
return return
logger.info(f"{mediainfo.get_title_string()} 媒体库中不存在,开始搜索 ...") logger.info(f"{mediainfo.get_title_string()} 媒体库中不存在,开始搜索 ...")
self.common.post_message( self.post_message(
title=f"开始搜索 {mediainfo.type.value} {mediainfo.get_title_string()} ...", userid=userid) title=f"开始搜索 {mediainfo.type.value} {mediainfo.get_title_string()} ...", userid=userid)
# 搜索种子 # 搜索种子
contexts = self.searchchain.process(meta=self._current_meta, mediainfo=mediainfo) contexts = self.searchchain.process(meta=self._current_meta, mediainfo=mediainfo)
if not contexts: if not contexts:
# 没有数据 # 没有数据
self.common.post_message(title=f"{mediainfo.title} 未搜索到资源!", userid=userid) self.post_message(title=f"{mediainfo.title} 未搜索到资源!", userid=userid)
return return
# 更新缓存 # 更新缓存
self._user_cache[userid] = { self._user_cache[userid] = {
@ -109,34 +110,23 @@ class UserMessageChain(ChainBase):
elif cache_type == "Subscribe": elif cache_type == "Subscribe":
# 订阅媒体 # 订阅媒体
mediainfo: MediaInfo = cache_list[int(text) - 1] mediainfo: MediaInfo = cache_list[int(text) - 1]
# 补充识别媒体信息 self.subscribechain.process(title=mediainfo.title,
mediainfo: MediaInfo = self.recognize_media(meta=self._current_meta, tmdbid=mediainfo.tmdb_id) mtype=mediainfo.type,
if not mediainfo: tmdbid=mediainfo.tmdb_id,
logger.warn(f'未识别到媒体信息tmdbid{mediainfo.tmdb_id}') season=self._current_meta.begin_season,
return userid=userid,
self._current_media = mediainfo username=username)
state, msg = self.subscribes.add(mediainfo,
season=self._current_meta.begin_season)
if state:
# 订阅成功
self.common.post_message(
title=f"{mediainfo.get_title_string()} 已添加订阅",
image=mediainfo.get_message_image(),
userid=userid)
else:
# 订阅失败
self.common.post_message(title=f"{mediainfo.title} 添加订阅失败:{msg}", userid=userid)
elif cache_type == "Torrent": elif cache_type == "Torrent":
if int(text) == 0: if int(text) == 0:
# 自动选择下载 # 自动选择下载
# 查询缺失的媒体信息 # 查询缺失的媒体信息
exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=self._current_media) exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=self._current_media)
if exist_flag: if exist_flag:
self.common.post_message(title=f"{self._current_media.get_title_string()} 媒体库中已存在", self.post_message(title=f"{self._current_media.get_title_string()} 媒体库中已存在",
userid=userid) userid=userid)
return return
# 批量下载 # 批量下载
downloads, lefts = self.common.batch_download(contexts=cache_list, downloads, lefts = self.downloadchain.batch_download(contexts=cache_list,
need_tvs=no_exists, need_tvs=no_exists,
userid=userid) userid=userid)
if downloads and not lefts: if downloads and not lefts:
@ -146,14 +136,12 @@ class UserMessageChain(ChainBase):
# 未完成下载 # 未完成下载
logger.info(f'{self._current_media.get_title_string()} 未下载未完整,添加订阅 ...') logger.info(f'{self._current_media.get_title_string()} 未下载未完整,添加订阅 ...')
# 添加订阅 # 添加订阅
state, msg = self.subscribes.add(self._current_media, self.subscribechain.process(title=self._current_media.title,
season=self._current_meta.begin_season) mtype=self._current_media.type,
if state: tmdbid=self._current_media.tmdb_id,
# 订阅成功 season=self._current_meta.begin_season,
self.common.post_message( userid=userid,
title=f"{self._current_media.get_title_string()} 已添加订阅", username=username)
text=f"来自用户:{userid}",
image=self._current_media.get_message_image())
else: else:
# 下载种子 # 下载种子
context: Context = cache_list[int(text) - 1] context: Context = cache_list[int(text) - 1]
@ -181,24 +169,24 @@ class UserMessageChain(ChainBase):
# 发送消息 # 发送消息
if not state: if not state:
# 下载失败 # 下载失败
self.common.post_message(title=f"{torrent.title} 添加下载失败!", self.post_message(title=f"{torrent.title} 添加下载失败!",
text=f"错误信息:{msg}", text=f"错误信息:{msg}",
userid=userid) userid=userid)
return return
# 下载成功,发送通知 # 下载成功,发送通知
self.common.post_download_message(meta=meta, mediainfo=self._current_media, torrent=torrent) self.downloadchain.post_download_message(meta=meta, mediainfo=self._current_media, torrent=torrent)
elif text.lower() == "p": elif text.lower() == "p":
# 上一页 # 上一页
cache_data: dict = self._user_cache.get(userid) cache_data: dict = self._user_cache.get(userid)
if not cache_data: if not cache_data:
# 没有缓存 # 没有缓存
self.common.post_message(title="输入有误!", userid=userid) self.post_message(title="输入有误!", userid=userid)
return return
if self._current_page == 0: if self._current_page == 0:
# 第一页 # 第一页
self.common.post_message(title="已经是第一页了!", userid=userid) self.post_message(title="已经是第一页了!", userid=userid)
return return
cache_type: str = cache_data.get('type') cache_type: str = cache_data.get('type')
cache_list: list = cache_data.get('items') cache_list: list = cache_data.get('items')
@ -222,7 +210,7 @@ class UserMessageChain(ChainBase):
cache_data: dict = self._user_cache.get(userid) cache_data: dict = self._user_cache.get(userid)
if not cache_data: if not cache_data:
# 没有缓存 # 没有缓存
self.common.post_message(title="输入有误!", userid=userid) self.post_message(title="输入有误!", userid=userid)
return return
cache_type: str = cache_data.get('type') cache_type: str = cache_data.get('type')
cache_list: list = cache_data.get('items') cache_list: list = cache_data.get('items')
@ -232,7 +220,7 @@ class UserMessageChain(ChainBase):
cache_list = cache_list[self._current_page * self._page_size:(self._current_page + 1) * self._page_size] cache_list = cache_list[self._current_page * self._page_size:(self._current_page + 1) * self._page_size]
if not cache_list: if not cache_list:
# 没有数据 # 没有数据
self.common.post_message(title="已经是最后一页了!", userid=userid) self.post_message(title="已经是最后一页了!", userid=userid)
return return
else: else:
if cache_type == "Torrent": if cache_type == "Torrent":
@ -257,7 +245,7 @@ class UserMessageChain(ChainBase):
# 识别 # 识别
meta = MetaInfo(title) meta = MetaInfo(title)
if not meta.get_name(): if not meta.get_name():
self.common.post_message(title="无法识别输入内容!", userid=userid) self.post_message(title="无法识别输入内容!", userid=userid)
return return
# 合并信息 # 合并信息
if mtype: if mtype:
@ -275,7 +263,7 @@ class UserMessageChain(ChainBase):
logger.info(f"开始搜索:{meta.get_name()}") logger.info(f"开始搜索:{meta.get_name()}")
medias: Optional[List[MediaInfo]] = self.search_medias(meta=meta) medias: Optional[List[MediaInfo]] = self.search_medias(meta=meta)
if not medias: if not medias:
self.common.post_message(title=f"{meta.get_name()} 没有找到对应的媒体信息!", userid=userid) self.post_message(title=f"{meta.get_name()} 没有找到对应的媒体信息!", userid=userid)
return return
self._user_cache[userid] = { self._user_cache[userid] = {
'type': action, 'type': action,

View File

@ -2,7 +2,7 @@
from unittest import TestCase from unittest import TestCase
from app.chain.common import CommonChain from app.chain.download import DownloadChain
from app.chain.identify import IdentifyChain from app.chain.identify import IdentifyChain
@ -16,5 +16,5 @@ class RecognizeTest(TestCase):
def test_recognize(self): def test_recognize(self):
result = IdentifyChain().process(title="我和我的祖国 2019") result = IdentifyChain().process(title="我和我的祖国 2019")
self.assertEqual(str(result.media_info.tmdb_id), '612845') self.assertEqual(str(result.media_info.tmdb_id), '612845')
exists = CommonChain().get_no_exists_info(result.media_info) exists = DownloadChain().get_no_exists_info(result.media_info)
self.assertTrue(exists[0]) self.assertTrue(exists[0])