diff --git a/app/chain/download.py b/app/chain/download.py index 0dcf079b..e8126477 100644 --- a/app/chain/download.py +++ b/app/chain/download.py @@ -1,6 +1,6 @@ import re from pathlib import Path -from typing import List, Optional, Tuple, Set, Dict +from typing import List, Optional, Tuple, Set, Dict, Union from app.chain import ChainBase from app.core.context import MediaInfo, TorrentInfo, Context @@ -474,7 +474,7 @@ class DownloadChain(ChainBase): # 全部存在 return True, no_exists - def get_downloading(self): + def get_downloading(self, userid: Union[str, int] = None): """ 查询正在下载的任务,并发送消息 """ @@ -491,4 +491,4 @@ class DownloadChain(ChainBase): f"{StringUtils.str_filesize(torrent.size)} " f"{round(torrent.progress * 100, 1)}%") index += 1 - self.post_message(title=title, text="\n".join(messages)) + self.post_message(title=title, text="\n".join(messages), userid=userid) diff --git a/app/chain/site_message.py b/app/chain/site_message.py index e19592cc..23ad3fb5 100644 --- a/app/chain/site_message.py +++ b/app/chain/site_message.py @@ -1,3 +1,5 @@ +from typing import Union + from app.chain import ChainBase from app.core.config import settings from app.db.site_oper import SiteOper @@ -18,7 +20,7 @@ class SiteMessageChain(ChainBase): self._siteoper = SiteOper() self._cookiehelper = CookieHelper() - def process(self): + def process(self, userid: Union[str, int] = None): """ 查询所有站点,发送消息 """ @@ -40,9 +42,9 @@ class SiteMessageChain(ChainBase): else: messages.append(f"{site.id}. {site.name}") # 发送列表 - self.post_message(title=title, text="\n".join(messages)) + self.post_message(title=title, text="\n".join(messages), userid=userid) - def disable(self, arg_str): + def disable(self, arg_str, userid: Union[str, int] = None): """ 禁用站点 """ @@ -54,7 +56,7 @@ class SiteMessageChain(ChainBase): site_id = int(arg_str) site = self._siteoper.get(site_id) if not site: - self.post_message(title=f"站点编号 {site_id} 不存在!") + self.post_message(title=f"站点编号 {site_id} 不存在!", userid=userid) return # 禁用站点 self._siteoper.update(site_id, { @@ -63,7 +65,7 @@ class SiteMessageChain(ChainBase): # 重新发送消息 self.process() - def enable(self, arg_str): + def enable(self, arg_str, userid: Union[str, int] = None): """ 启用站点 """ @@ -75,7 +77,7 @@ class SiteMessageChain(ChainBase): site_id = int(arg_str) site = self._siteoper.get(site_id) if not site: - self.post_message(title=f"站点编号 {site_id} 不存在!") + self.post_message(title=f"站点编号 {site_id} 不存在!", userid=userid) return # 禁用站点 self._siteoper.update(site_id, { @@ -84,30 +86,30 @@ class SiteMessageChain(ChainBase): # 重新发送消息 self.process() - def get_cookie(self, arg_str: str): + def get_cookie(self, arg_str: str, userid: Union[str, int] = None): """ 使用用户名密码更新站点Cookie """ err_title = "请输入正确的命令格式:/site_cookie [id] [username] [password]," \ "[id]为站点编号,[uername]为站点用户名,[password]为站点密码" if not arg_str: - self.post_message(title=err_title) + self.post_message(title=err_title, userid=userid) return arg_str = arg_str.strip() args = arg_str.split() if len(args) != 3: - self.post_message(title=err_title) + self.post_message(title=err_title, userid=userid) return site_id = args[0] if not site_id.isdigit(): - self.post_message(title=err_title) + self.post_message(title=err_title, userid=userid) return # 站点ID site_id = int(site_id) # 站点信息 site_info = self._siteoper.get(site_id) if not site_info: - self.post_message(title=f"站点编号 {site_id} 不存在!") + self.post_message(title=f"站点编号 {site_id} 不存在!", userid=userid) return # 用户名 username = args[1] @@ -125,10 +127,12 @@ class SiteMessageChain(ChainBase): if not cookie: logger.error(msg) self.post_message(title=f"【{site_info.name}】 Cookie&UA更新失败!", - text=f"错误原因:{msg}") + text=f"错误原因:{msg}", + userid=userid) return self._siteoper.update(site_id, { "cookie": cookie, "ua": ua }) - self.post_message(title=f"【{site_info.name}】 Cookie&UA更新成功") + self.post_message(title=f"【{site_info.name}】 Cookie&UA更新成功", + userid=userid) diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index bced38fc..75309d93 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from app.chain import ChainBase from app.chain.download import DownloadChain @@ -295,13 +295,13 @@ class SubscribeChain(ChainBase): "lack_episode": len(left_episodes) }) - def list(self): + def list(self, userid: Union[str, int] = None): """ 查询订阅并发送消息 """ subscribes = self.subscribehelper.list() if not subscribes: - self.post_message(title='没有任何订阅!') + self.post_message(title='没有任何订阅!', userid=userid) return title = f"共有 {len(subscribes)} 个订阅,回复对应指令操作: " \ f"\n- 删除订阅:/subscribe_delete [id]" @@ -317,9 +317,9 @@ class SubscribeChain(ChainBase): f"_{subscribe.total_episode - (subscribe.lack_episode or subscribe.total_episode)}" f"/{subscribe.total_episode}_") # 发送列表 - self.post_message(title=title, text='\n'.join(messages)) + self.post_message(title=title, text='\n'.join(messages), userid=userid) - def delete(self, arg_str: str): + def delete(self, arg_str: str, userid: Union[str, int] = None): """ 删除订阅 """ @@ -331,7 +331,7 @@ class SubscribeChain(ChainBase): subscribe_id = int(arg_str) subscribe = self.subscribehelper.get(subscribe_id) if not subscribe: - self.post_message(title=f"订阅编号 {subscribe_id} 不存在!") + self.post_message(title=f"订阅编号 {subscribe_id} 不存在!", userid=userid) return # 删除订阅 self.subscribehelper.delete(subscribe_id) diff --git a/app/chain/transfer.py b/app/chain/transfer.py index 124c48bb..f041d464 100644 --- a/app/chain/transfer.py +++ b/app/chain/transfer.py @@ -1,6 +1,6 @@ import re from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Union from app.chain import ChainBase from app.core.context import MediaInfo @@ -18,7 +18,7 @@ class TransferChain(ChainBase): 文件转移处理链 """ - def process(self, arg_str: str = None) -> bool: + def process(self, arg_str: str = None, userid: Union[str, int] = None) -> bool: """ 获取下载器中的种子列表,并执行转移 """ @@ -87,7 +87,8 @@ class TransferChain(ChainBase): if not mediainfo: logger.warn(f'未识别到媒体信息,标题:{torrent.title}') self.post_message(title=f"{torrent.title} 未识别到媒体信息,无法入库!\n" - f"回复:```\n/transfer {torrent.hash} [tmdbid]\n``` 手动识别转移。") + f"回复:```\n/transfer {torrent.hash} [tmdbid]\n``` 手动识别转移。", + userid=userid) continue else: mediainfo = arg_mediainfo @@ -101,7 +102,8 @@ class TransferChain(ChainBase): self.post_message( title=f"{mediainfo.title_year}{meta.season_episode} 入库失败!", text=f"原因:{transferinfo.message if transferinfo else '未知'}", - image=mediainfo.get_message_image() + image=mediainfo.get_message_image(), + userid=userid ), continue # 转移完成 diff --git a/app/chain/user_message.py b/app/chain/user_message.py index 138fad48..e4a6a05e 100644 --- a/app/chain/user_message.py +++ b/app/chain/user_message.py @@ -58,7 +58,8 @@ class UserMessageChain(ChainBase): self.eventmanager.send_event( EventType.CommandExcute, { - "cmd": text + "cmd": text, + "user": userid } ) self.post_message(title=f"正在运行,请稍候 ...", userid=userid) diff --git a/app/command.py b/app/command.py index e47903c4..74c9c35e 100644 --- a/app/command.py +++ b/app/command.py @@ -1,6 +1,7 @@ +import inspect import traceback from threading import Thread, Event -from typing import Any +from typing import Any, Union from app.chain import ChainBase from app.chain.cookiecloud import CookieCloudChain @@ -13,6 +14,7 @@ from app.core.event import eventmanager, EventManager from app.core.plugin import PluginManager from app.core.event import Event as ManagerEvent from app.log import logger +from app.utils.object import ObjectUtils from app.utils.singleton import Singleton from app.utils.types import EventType @@ -173,22 +175,29 @@ class Command(metaclass=Singleton): """ return self._commands.get(cmd, {}) - def execute(self, cmd: str, data_str: str = "") -> None: + def execute(self, cmd: str, data_str: str = "", userid: Union[str, int] = None) -> None: """ 执行命令 """ command = self.get(cmd) if command: try: - logger.info(f"开始执行:{command.get('description')} ...") + logger.info(f"用户 {userid} 开始执行:{command.get('description')} ...") cmd_data = command['data'] if command.get('data') else {} - if cmd_data: - command['func'](**cmd_data) - elif data_str: - command['func'](data_str) + if ObjectUtils.has_arguments(command['func']): + if cmd_data: + # 使用内置参数 + command['func'](**cmd_data) + elif data_str: + # 使用用户输入参数 + command['func'](data_str, userid) + else: + # 没有用户输入参数 + command['func'](userid) else: + # 没有参数 command['func']() - logger.info(f"{command.get('description')} 执行完成") + logger.info(f"用户 {userid} {command.get('description')} 执行完成") except Exception as err: logger.error(f"执行命令 {cmd} 出错:{str(err)}") traceback.print_exc() @@ -208,9 +217,12 @@ class Command(metaclass=Singleton): "cmd": "/xxx args" } """ + # 命令参数 event_str = event.event_data.get('cmd') + # 消息用户 + event_user = event.event_data.get('user') if event_str: cmd = event_str.split()[0] args = " ".join(event_str.split()[1:]) if self.get(cmd): - self.execute(cmd, args) + self.execute(cmd, args, event_user) diff --git a/app/utils/object.py b/app/utils/object.py index 9265bfd8..09959ffd 100644 --- a/app/utils/object.py +++ b/app/utils/object.py @@ -1,3 +1,4 @@ +import inspect from typing import Any @@ -9,3 +10,18 @@ class ObjectUtils: return True else: return str(obj).startswith("{") or str(obj).startswith("[") + + @staticmethod + def has_arguments(func): + """ + 判断函数是否有参数 + """ + signature = inspect.signature(func) + parameters = signature.parameters + parameter_names = list(parameters.keys()) + + # 排除 self 参数 + if parameter_names and parameter_names[0] == 'self': + parameter_names = parameter_names[1:] + + return len(parameter_names) > 0