diff --git a/app/api/endpoints/messages.py b/app/api/endpoints/messages.py index a5ee8b3b..f7b5e142 100644 --- a/app/api/endpoints/messages.py +++ b/app/api/endpoints/messages.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Any from fastapi import APIRouter, BackgroundTasks from fastapi import Request @@ -12,11 +12,11 @@ from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt router = APIRouter() -def start_message_chain(request: Request): +def start_message_chain(body: Any, form: Any, args: Any): """ 启动链式任务 """ - UserMessageChain().process(request) + UserMessageChain().process(body=body, form=form, args=args) @router.post("/", response_model=schemas.Response) @@ -24,7 +24,10 @@ async def user_message(background_tasks: BackgroundTasks, request: Request): """ 用户消息响应 """ - background_tasks.add_task(start_message_chain, request) + body = await request.body() + form = await request.form() + args = request.query_params + background_tasks.add_task(start_message_chain, body, form, args) return {"success": True} diff --git a/app/chain/common.py b/app/chain/common.py index bae11a33..bcc7fa1f 100644 --- a/app/chain/common.py +++ b/app/chain/common.py @@ -150,7 +150,7 @@ class CommonChain(ChainBase): for nt in need_tvs.get(tmdbid): if cur == nt.get("season") or (cur == 1 and not nt.get("season")): need_tvs[tmdbid].remove(nt) - if not need_tvs.get(tmdbid): + if not need_tvs.get(tmdbid) and need_tvs.get(tmdbid) is not None: need_tvs.pop(tmdbid) return need @@ -163,7 +163,7 @@ class CommonChain(ChainBase): need_tvs[tmdbid][seq]["episodes"] = need else: need_tvs[tmdbid].pop(seq) - if not need_tvs.get(tmdbid): + if not need_tvs.get(tmdbid) and need_tvs.get(tmdbid) is not None: need_tvs.pop(tmdbid) return need diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 25bab402..f5272e85 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -61,7 +61,7 @@ class SubscribeChain(ChainBase): else: logger.error(f'{mediainfo.get_title_string()} 添加订阅成功') self.common.post_message(title=f"{mediainfo.get_title_string()} 已添加订阅", - text="用户:{username}", + text=f"来自用户:{username}", image=mediainfo.get_message_image()) # 返回结果 return state diff --git a/app/chain/user_message.py b/app/chain/user_message.py index c91e654b..becedfe3 100644 --- a/app/chain/user_message.py +++ b/app/chain/user_message.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Any from fastapi import Request @@ -33,23 +33,23 @@ class UserMessageChain(ChainBase): self.searchchain = SearchChain() self.torrent = TorrentHelper() - def process(self, request: Request, *args, **kwargs) -> None: + def process(self, body: Any, form: Any, args: Any) -> None: """ 识别消息内容,执行操作 """ # 获取消息内容 - info: dict = self.run_module('message_parser', request=request) + info: dict = self.run_module('message_parser', body=body, form=form, args=args) if not info: return # 用户ID userid = info.get('userid') if not userid: - logger.debug(f'未识别到用户ID:{request}') + logger.debug(f'未识别到用户ID:{body}{form}{args}') return # 消息内容 text = str(info.get('text')).strip() if info.get('text') else None if not text: - logger.debug(f'未识别到消息内容:{request}') + logger.debug(f'未识别到消息内容::{body}{form}{args}') return logger.info(f'收到用户消息内容,用户:{userid},内容:{text}') if text.startswith('/'): @@ -76,7 +76,7 @@ class UserMessageChain(ChainBase): cache_list: list = cache_data.get('items') # 选择 if cache_type == "Search": - mediainfo: MediaInfo = cache_list[int(text) - 1] + mediainfo: MediaInfo = cache_list[int(text) + self._current_page * self._page_size - 1] self._current_media = mediainfo # 检查是否已存在 exists: list = self.run_module('media_exists', mediainfo=mediainfo) @@ -86,6 +86,8 @@ class UserMessageChain(ChainBase): title=f"{mediainfo.type.value} {mediainfo.get_title_string()} 媒体库中已存在", userid=userid) return logger.info(f"{mediainfo.get_title_string()} 媒体库中不存在,开始搜索 ...") + self.common.post_message( + title=f"开始搜索 {mediainfo.type.value} {mediainfo.get_title_string()} ...", userid=userid) # 搜索种子 contexts = self.searchchain.process(meta=self._current_meta, mediainfo=mediainfo) if not contexts: @@ -100,7 +102,7 @@ class UserMessageChain(ChainBase): self._current_page = 0 # 发送种子数据 logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...") - self.__post_torrents_message(items=contexts[:self._page_size], userid=userid) + self.__post_torrents_message(items=contexts[:self._page_size], userid=userid, total=len(contexts)) elif cache_type == "Subscribe": # 订阅媒体 @@ -144,11 +146,12 @@ class UserMessageChain(ChainBase): # 订阅成功 self.common.post_message( title=f"{self._current_media.get_title_string()} 已添加订阅", - text=f"用户:{userid}", + text=f"来自用户:{userid}", image=self._current_media.get_message_image()) else: # 下载种子 - torrent: TorrentInfo = cache_list[int(text) - 1] + context: Context = cache_list[int(text) - 1] + torrent: TorrentInfo = context.torrent_info logger.info(f"开始下载种子:{torrent.title} - {torrent.enclosure}") meta: MetaBase = MetaInfo(torrent.title) torrent_file, _, _, _, error_msg = self.torrent.download_torrent( @@ -205,10 +208,10 @@ class UserMessageChain(ChainBase): end = start + self._page_size if cache_type == "Torrent": # 发送种子数据 - self.__post_torrents_message(items=cache_list[start:end], userid=userid) + self.__post_torrents_message(items=cache_list[start:end], userid=userid, total=len(cache_list)) else: # 发送媒体数据 - self.__post_medias_message(items=cache_list[start:end], userid=userid) + self.__post_medias_message(items=cache_list[start:end], userid=userid, total=len(cache_list)) elif text.lower() == "n": # 下一页 @@ -219,9 +222,10 @@ class UserMessageChain(ChainBase): return cache_type: str = cache_data.get('type') cache_list: list = cache_data.get('items') + total = len(cache_list) # 加一页 self._current_page += 1 - cache_list = cache_list[self._current_page * self._page_size:] + cache_list = cache_list[self._current_page * self._page_size:(self._current_page + 1) * self._page_size] if not cache_list: # 没有数据 self.common.post_message(title="已经是最后一页了!", userid=userid) @@ -229,10 +233,10 @@ class UserMessageChain(ChainBase): else: if cache_type == "Torrent": # 发送种子数据 - self.__post_torrents_message(items=cache_list, userid=userid) + self.__post_torrents_message(items=cache_list, userid=userid, total=total) else: # 发送媒体数据 - self.__post_medias_message(items=cache_list, userid=userid) + self.__post_medias_message(items=cache_list, userid=userid, total=total) else: # 搜索或订阅 @@ -274,22 +278,22 @@ class UserMessageChain(ChainBase): self._current_page = 0 self._current_media = None # 发送媒体列表 - self.__post_medias_message(items=medias[:self._page_size], userid=userid) + self.__post_medias_message(items=medias[:self._page_size], userid=userid, total=len(medias)) - def __post_medias_message(self, items: list, userid: str): + def __post_medias_message(self, items: list, userid: str, total: int): """ 发送媒体列表消息 """ self.run_module('post_medias_message', - title="请回复数字选择对应媒体(p:上一页, n:下一页)", + title=f"共找到{total}条相关信息,请回复数字选择对应媒体(p:上一页 n:下一页)", items=items, userid=userid) - def __post_torrents_message(self, items: list, userid: str): + def __post_torrents_message(self, items: list, userid: str, total: int): """ 发送种子列表消息 """ self.run_module('post_torrents_message', - title="请回复数字下载对应资源(0:自动选择, p:上一页, n:下一页)", + title=f"共找到{total}条相关信息,请回复数字下载对应资源(0:自动选择 p:上一页 n:下一页)", items=items, userid=userid) diff --git a/app/modules/__init__.py b/app/modules/__init__.py index 57b3201c..11fbcb8c 100644 --- a/app/modules/__init__.py +++ b/app/modules/__init__.py @@ -1,11 +1,10 @@ from abc import abstractmethod, ABCMeta from pathlib import Path -from typing import Optional, List, Tuple, Union, Set +from typing import Optional, List, Tuple, Union, Set, Any -from fastapi import Request from ruamel.yaml import CommentedMap -from app.core.context import MediaInfo, TorrentInfo +from app.core.context import MediaInfo, TorrentInfo, Context from app.core.meta import MetaBase from app.utils.types import TorrentStatus @@ -59,13 +58,15 @@ class _ModuleBase(metaclass=ABCMeta): """ pass - def message_parser(self, request: Request) -> Optional[dict]: + def message_parser(self, body: Any, form: Any, args: Any) -> Optional[dict]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID username: 用户名 text: 内容 - :param request: 请求体 + :param body: 请求体 + :param form: 表单 + :param args: 参数 :return: 消息内容、用户ID """ pass @@ -205,7 +206,7 @@ class _ModuleBase(metaclass=ABCMeta): """ pass - def post_torrents_message(self, title: str, items: List[TorrentInfo], + def post_torrents_message(self, title: str, items: List[Context], userid: Union[str, int] = None) -> Optional[bool]: """ 发送种子信息选择列表 diff --git a/app/modules/telegram/__init__.py b/app/modules/telegram/__init__.py index 9492780b..97311ac5 100644 --- a/app/modules/telegram/__init__.py +++ b/app/modules/telegram/__init__.py @@ -1,8 +1,7 @@ -from typing import Optional, Union, List, Tuple +import json +from typing import Optional, Union, List, Tuple, Any -from fastapi import Request - -from app.core import MediaInfo, TorrentInfo, settings +from app.core import MediaInfo, settings, Context from app.log import logger from app.modules import _ModuleBase from app.modules.telegram.telegram import Telegram @@ -18,13 +17,15 @@ class TelegramModule(_ModuleBase): def init_setting(self) -> Tuple[str, Union[str, bool]]: return "MESSAGER", "telegram" - async def message_parser(self, request: Request) -> Optional[dict]: + def message_parser(self, body: Any, form: Any, args: Any) -> Optional[dict]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID username: 用户名 text: 内容 - :param request: 请求体 + :param body: 请求体 + :param form: 表单 + :param args: 参数 :return: 消息内容、用户ID """ """ @@ -50,7 +51,11 @@ class TelegramModule(_ModuleBase): } } """ - msg_json: dict = await request.json() + try: + msg_json: dict = json.loads(body) + except Exception as err: + logger.error(f"解析Telegram消息失败:{err}") + return None if msg_json: message = msg_json.get("message", {}) text = message.get("text") @@ -61,12 +66,15 @@ class TelegramModule(_ModuleBase): logger.info(f"收到Telegram消息:userid={user_id}, username={user_name}, text={text}") # 检查权限 if text.startswith("/"): - if str(user_id) not in settings.TELEGRAM_ADMINS.split(',') \ + if settings.TELEGRAM_ADMINS \ + and str(user_id) not in settings.TELEGRAM_ADMINS.split(',') \ and str(user_id) != settings.TELEGRAM_CHAT_ID: self.telegram.send_msg(title="只有管理员才有权限执行此命令", userid=user_id) return {} else: - if not str(user_id) in settings.TELEGRAM_USERS.split(','): + if settings.TELEGRAM_USERS \ + and not str(user_id) in settings.TELEGRAM_USERS.split(','): + logger.info(f"用户{user_id}不在用户白名单中,无法使用此机器人") self.telegram.send_msg(title="你不在用户白名单中,无法使用此机器人", userid=user_id) return {} return { @@ -99,7 +107,7 @@ class TelegramModule(_ModuleBase): """ return self.telegram.send_meidas_msg(title=title, medias=items, userid=userid) - def post_torrents_message(self, title: str, items: List[TorrentInfo], + def post_torrents_message(self, title: str, items: List[Context], userid: Union[str, int] = None) -> Optional[bool]: """ 发送种子信息选择列表 diff --git a/app/modules/telegram/telegram.py b/app/modules/telegram/telegram.py index 895a21e9..c5b711c3 100644 --- a/app/modules/telegram/telegram.py +++ b/app/modules/telegram/telegram.py @@ -2,7 +2,7 @@ from threading import Event, Thread from typing import Optional, List from urllib.parse import urlencode -from app.core import settings, MediaInfo, TorrentInfo +from app.core import settings, MediaInfo, TorrentInfo, Context from app.log import logger from app.utils.http import RequestUtils from app.utils.singleton import Singleton @@ -28,6 +28,7 @@ class Telegram(metaclass=Singleton): # 消息轮循 if self._telegram_token and self._telegram_chat_id: self._thread = Thread(target=self.__start_telegram_message_proxy) + self._thread.start() def send_msg(self, title: str, text: str = "", image: str = "", userid: str = "") -> Optional[bool]: """ @@ -46,19 +47,10 @@ class Telegram(metaclass=Singleton): return False try: - # text中的Markdown特殊字符转义 - text = text.replace("[", r"\[").replace("_", r"\_").replace("*", r"\*").replace("`", r"\`") - # 拼装消息内容 - titles = str(title).split('\n') - if len(titles) > 1: - title = titles[0] - if not text: - text = "\n".join(titles[1:]) - else: - text = "%s\n%s" % ("\n".join(titles[1:]), text) - if text: - caption = "*%s*\n%s" % (title, text.replace("\n\n", "\n")) + # text中的Markdown特殊字符转义 + text = text.replace("[", r"\[").replace("_", r"\_").replace("*", r"\*").replace("`", r"\`").replace("\n\n", "\n") + caption = f"*{title}*\n{text}" else: caption = title @@ -85,19 +77,19 @@ class Telegram(metaclass=Singleton): for media in medias: if not image: image = media.get_message_image() - if media.get_vote_string(): + if media.vote_average: caption = "%s\n%s. [%s](%s)\n%s,%s" % (caption, index, media.get_title_string(), media.get_detail_url(), - media.get_type_string(), - media.get_vote_string()) + f"类型:{media.type.value}", + f"评分:{media.vote_average}") else: caption = "%s\n%s. [%s](%s)\n%s" % (caption, index, media.get_title_string(), media.get_detail_url(), - media.get_type_string()) + f"类型:{media.type.value}") index += 1 if userid: @@ -111,7 +103,7 @@ class Telegram(metaclass=Singleton): logger.error(f"发送消息失败:{msg_e}") return False - def send_torrents_msg(self, torrents: List[TorrentInfo], userid: str = "", title: str = "") -> Optional[bool]: + def send_torrents_msg(self, torrents: List[Context], userid: str = "", title: str = "") -> Optional[bool]: """ 发送列表消息 """ @@ -120,7 +112,8 @@ class Telegram(metaclass=Singleton): try: index, caption = 1, "*%s*" % title - for torrent in torrents: + for context in torrents: + torrent = context.torrent_info link = torrent.page_url title = torrent.title free = torrent.get_volume_factor_string() @@ -167,18 +160,13 @@ class Telegram(metaclass=Singleton): # 发送图文消息 if image: - res = request.get_res("https://api.telegram.org/bot%s/sendPhoto?" % self._telegram_token + urlencode( - {"chat_id": chat_id, "photo": image, "caption": caption, "parse_mode": "Markdown"})) - if __res_parse(res): - return True - else: - photo_req = request.get_res(image) - if photo_req and photo_req.content: - res = request.post_res("https://api.telegram.org/bot%s/sendPhoto" % self._telegram_token, - data={"chat_id": chat_id, "caption": caption, "parse_mode": "Markdown"}, - files={"photo": photo_req.content}) - if __res_parse(res): - return True + photo_req = request.get_res(image) + if photo_req and photo_req.content: + res = request.post_res("https://api.telegram.org/bot%s/sendPhoto" % self._telegram_token, + data={"chat_id": chat_id, "caption": caption, "parse_mode": "Markdown"}, + files={"photo": photo_req.content}) + if __res_parse(res): + return True # 发送文本消息 res = request.get_res("https://api.telegram.org/bot%s/sendMessage?" % self._telegram_token + urlencode( {"chat_id": chat_id, "text": caption, "parse_mode": "Markdown"})) diff --git a/app/modules/wechat/__init__.py b/app/modules/wechat/__init__.py index 5d0b9562..d0259d6a 100644 --- a/app/modules/wechat/__init__.py +++ b/app/modules/wechat/__init__.py @@ -1,8 +1,7 @@ -from typing import Optional, Union, List, Tuple - -from fastapi import Request import xml.dom.minidom -from app.core import MediaInfo, TorrentInfo, settings +from typing import Optional, Union, List, Tuple, Any + +from app.core import MediaInfo, settings, Context from app.log import logger from app.modules import _ModuleBase from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt @@ -20,26 +19,31 @@ class WechatModule(_ModuleBase): def init_setting(self) -> Tuple[str, Union[str, bool]]: return "MESSAGER", "wechat" - def message_parser(self, request: Request) -> Optional[dict]: + def message_parser(self, body: Any, form: Any, args: Any) -> Optional[dict]: """ 解析消息内容,返回字典,注意以下约定值: userid: 用户ID username: 用户名 text: 内容 - :param request: 请求体 + :param body: 请求体 + :param form: 表单 + :param args: 参数 :return: 消息内容、用户ID """ try: # URL参数 - sVerifyMsgSig = request.query_params.get("msg_signature") - sVerifyTimeStamp = request.query_params.get("timestamp") - sVerifyNonce = request.query_params.get("nonce") + sVerifyMsgSig = args.get("msg_signature") + sVerifyTimeStamp = args.get("timestamp") + sVerifyNonce = args.get("nonce") + if not sVerifyMsgSig or not sVerifyTimeStamp or not sVerifyNonce: + logger.error(f"微信请求参数错误:{args}") + return None # 解密模块 wxcpt = WXBizMsgCrypt(sToken=settings.WECHAT_TOKEN, sEncodingAESKey=settings.WECHAT_ENCODING_AESKEY, sReceiveId=settings.WECHAT_CORPID) # 报文数据 - sReqData = request.form() + sReqData = form if not sReqData: return None logger.debug(f"收到微信请求:{sReqData}") @@ -132,7 +136,7 @@ class WechatModule(_ModuleBase): # 再发送内容 return self.wechat.send_medias_msg(medias=items, userid=userid) - def post_torrents_message(self, title: str, items: List[TorrentInfo], + def post_torrents_message(self, title: str, items: List[Context], userid: Union[str, int] = None) -> Optional[bool]: """ 发送种子信息选择列表 diff --git a/app/modules/wechat/wechat.py b/app/modules/wechat/wechat.py index 8fd68b3a..337f57e2 100644 --- a/app/modules/wechat/wechat.py +++ b/app/modules/wechat/wechat.py @@ -3,7 +3,7 @@ import threading from datetime import datetime from typing import Optional, List -from app.core import settings, MediaInfo, TorrentInfo +from app.core import settings, MediaInfo, Context from app.log import logger from app.utils.http import RequestUtils from app.utils.singleton import Singleton @@ -165,10 +165,10 @@ class WeChat(metaclass=Singleton): articles = [] index = 1 for media in medias: - if media.get_vote_string(): - title = f"{index}. {media.get_title_string()}\n{media.get_type_string()},{media.get_vote_string()}" + if media.vote_average: + title = f"{index}. {media.get_title_string()}\n类型:{media.type.value},评分:{media.vote_average}" else: - title = f"{index}. {media.get_title_string()}\n{media.get_type_string()}" + title = f"{index}. {media.get_title_string()}\n类型:{media.type.value}" articles.append({ "title": title, "description": "", @@ -187,7 +187,7 @@ class WeChat(metaclass=Singleton): } return self.__post_request(message_url, req_json) - def send_torrents_msg(self, torrents: List[TorrentInfo], userid: str = "", title: str = "") -> Optional[bool]: + def send_torrents_msg(self, torrents: List[Context], userid: str = "", title: str = "") -> Optional[bool]: """ 发送列表消息 """ @@ -197,7 +197,8 @@ class WeChat(metaclass=Singleton): try: index, caption = 1, "*%s*" % title - for torrent in torrents: + for context in torrents: + torrent = context.torrent_info link = torrent.page_url title = torrent.title free = torrent.get_volume_factor_string()