fix chain depends

This commit is contained in:
jxxghp
2023-06-09 23:44:55 +08:00
parent b9012c5666
commit 51da075a65
6 changed files with 71 additions and 80 deletions

View File

@ -1,11 +1,11 @@
from typing import Any
from app.chain.common import *
from app.chain.download import *
from app.chain.search import SearchChain
from app.chain.subscribe import SubscribeChain
from app.core.context import MediaInfo, TorrentInfo
from app.core.meta_info import MetaInfo
from app.core.event_manager import EventManager
from app.db.subscribes import Subscribes
from app.log import logger
from app.utils.types import EventType
@ -27,8 +27,8 @@ class UserMessageChain(ChainBase):
def __init__(self):
super().__init__()
self.common = CommonChain()
self.subscribes = Subscribes()
self.downloadchain = DownloadChain()
self.subscribechain = SubscribeChain()
self.searchchain = SearchChain()
self.torrent = TorrentHelper()
self.eventmanager = EventManager()
@ -43,6 +43,7 @@ class UserMessageChain(ChainBase):
return
# 用户ID
userid = info.get('userid')
username = info.get('username')
if not userid:
logger.debug(f'未识别到用户ID{body}{form}{args}')
return
@ -60,7 +61,7 @@ class UserMessageChain(ChainBase):
"cmd": text
}
)
self.common.post_message(title=f"正在运行,请稍候 ...", userid=userid)
self.post_message(title=f"正在运行,请稍候 ...", userid=userid)
elif text.isdigit():
# 缓存
@ -70,7 +71,7 @@ class UserMessageChain(ChainBase):
or not cache_data.get('items') \
or len(cache_data.get('items')) < int(text):
# 发送消息
self.common.post_message(title="输入有误!", userid=userid)
self.post_message(title="输入有误!", userid=userid)
return
# 缓存类型
cache_type: str = cache_data.get('type')
@ -84,17 +85,17 @@ class UserMessageChain(ChainBase):
exists: dict = self.media_exists(mediainfo=mediainfo)
if exists:
# 已存在
self.common.post_message(
self.post_message(
title=f"{mediainfo.type.value} {mediainfo.get_title_string()} 媒体库中已存在", userid=userid)
return
logger.info(f"{mediainfo.get_title_string()} 媒体库中不存在,开始搜索 ...")
self.common.post_message(
self.post_message(
title=f"开始搜索 {mediainfo.type.value} {mediainfo.get_title_string()} ...", userid=userid)
# 搜索种子
contexts = self.searchchain.process(meta=self._current_meta, mediainfo=mediainfo)
if not contexts:
# 没有数据
self.common.post_message(title=f"{mediainfo.title} 未搜索到资源!", userid=userid)
self.post_message(title=f"{mediainfo.title} 未搜索到资源!", userid=userid)
return
# 更新缓存
self._user_cache[userid] = {
@ -109,36 +110,25 @@ class UserMessageChain(ChainBase):
elif cache_type == "Subscribe":
# 订阅媒体
mediainfo: MediaInfo = cache_list[int(text) - 1]
# 补充识别媒体信息
mediainfo: MediaInfo = self.recognize_media(meta=self._current_meta, tmdbid=mediainfo.tmdb_id)
if not mediainfo:
logger.warn(f'未识别到媒体信息tmdbid{mediainfo.tmdb_id}')
return
self._current_media = mediainfo
state, msg = self.subscribes.add(mediainfo,
season=self._current_meta.begin_season)
if state:
# 订阅成功
self.common.post_message(
title=f"{mediainfo.get_title_string()} 已添加订阅",
image=mediainfo.get_message_image(),
userid=userid)
else:
# 订阅失败
self.common.post_message(title=f"{mediainfo.title} 添加订阅失败:{msg}", userid=userid)
self.subscribechain.process(title=mediainfo.title,
mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
season=self._current_meta.begin_season,
userid=userid,
username=username)
elif cache_type == "Torrent":
if int(text) == 0:
# 自动选择下载
# 查询缺失的媒体信息
exist_flag, no_exists = self.common.get_no_exists_info(mediainfo=self._current_media)
exist_flag, no_exists = self.downloadchain.get_no_exists_info(mediainfo=self._current_media)
if exist_flag:
self.common.post_message(title=f"{self._current_media.get_title_string()} 媒体库中已存在",
userid=userid)
self.post_message(title=f"{self._current_media.get_title_string()} 媒体库中已存在",
userid=userid)
return
# 批量下载
downloads, lefts = self.common.batch_download(contexts=cache_list,
need_tvs=no_exists,
userid=userid)
downloads, lefts = self.downloadchain.batch_download(contexts=cache_list,
need_tvs=no_exists,
userid=userid)
if downloads and not lefts:
# 全部下载完成
logger.info(f'{self._current_media.get_title_string()} 下载完成')
@ -146,14 +136,12 @@ class UserMessageChain(ChainBase):
# 未完成下载
logger.info(f'{self._current_media.get_title_string()} 未下载未完整,添加订阅 ...')
# 添加订阅
state, msg = self.subscribes.add(self._current_media,
season=self._current_meta.begin_season)
if state:
# 订阅成功
self.common.post_message(
title=f"{self._current_media.get_title_string()} 已添加订阅",
text=f"来自用户:{userid}",
image=self._current_media.get_message_image())
self.subscribechain.process(title=self._current_media.title,
mtype=self._current_media.type,
tmdbid=self._current_media.tmdb_id,
season=self._current_meta.begin_season,
userid=userid,
username=username)
else:
# 下载种子
context: Context = cache_list[int(text) - 1]
@ -181,24 +169,24 @@ class UserMessageChain(ChainBase):
# 发送消息
if not state:
# 下载失败
self.common.post_message(title=f"{torrent.title} 添加下载失败!",
text=f"错误信息:{msg}",
userid=userid)
self.post_message(title=f"{torrent.title} 添加下载失败!",
text=f"错误信息:{msg}",
userid=userid)
return
# 下载成功,发送通知
self.common.post_download_message(meta=meta, mediainfo=self._current_media, torrent=torrent)
self.downloadchain.post_download_message(meta=meta, mediainfo=self._current_media, torrent=torrent)
elif text.lower() == "p":
# 上一页
cache_data: dict = self._user_cache.get(userid)
if not cache_data:
# 没有缓存
self.common.post_message(title="输入有误!", userid=userid)
self.post_message(title="输入有误!", userid=userid)
return
if self._current_page == 0:
# 第一页
self.common.post_message(title="已经是第一页了!", userid=userid)
self.post_message(title="已经是第一页了!", userid=userid)
return
cache_type: str = cache_data.get('type')
cache_list: list = cache_data.get('items')
@ -222,7 +210,7 @@ class UserMessageChain(ChainBase):
cache_data: dict = self._user_cache.get(userid)
if not cache_data:
# 没有缓存
self.common.post_message(title="输入有误!", userid=userid)
self.post_message(title="输入有误!", userid=userid)
return
cache_type: str = cache_data.get('type')
cache_list: list = cache_data.get('items')
@ -232,7 +220,7 @@ class UserMessageChain(ChainBase):
cache_list = cache_list[self._current_page * self._page_size:(self._current_page + 1) * self._page_size]
if not cache_list:
# 没有数据
self.common.post_message(title="已经是最后一页了!", userid=userid)
self.post_message(title="已经是最后一页了!", userid=userid)
return
else:
if cache_type == "Torrent":
@ -257,7 +245,7 @@ class UserMessageChain(ChainBase):
# 识别
meta = MetaInfo(title)
if not meta.get_name():
self.common.post_message(title="无法识别输入内容!", userid=userid)
self.post_message(title="无法识别输入内容!", userid=userid)
return
# 合并信息
if mtype:
@ -275,7 +263,7 @@ class UserMessageChain(ChainBase):
logger.info(f"开始搜索:{meta.get_name()}")
medias: Optional[List[MediaInfo]] = self.search_medias(meta=meta)
if not medias:
self.common.post_message(title=f"{meta.get_name()} 没有找到对应的媒体信息!", userid=userid)
self.post_message(title=f"{meta.get_name()} 没有找到对应的媒体信息!", userid=userid)
return
self._user_cache[userid] = {
'type': action,