This commit is contained in:
jxxghp 2023-06-15 21:23:15 +08:00
parent fbfce9df52
commit 60526dbd2d
7 changed files with 71 additions and 36 deletions

View File

@ -1,6 +1,6 @@
import re
from pathlib import Path
from typing import List, Optional, Tuple, Set, Dict
from typing import List, Optional, Tuple, Set, Dict, Union
from app.chain import ChainBase
from app.core.context import MediaInfo, TorrentInfo, Context
@ -474,7 +474,7 @@ class DownloadChain(ChainBase):
# 全部存在
return True, no_exists
def get_downloading(self):
def get_downloading(self, userid: Union[str, int] = None):
"""
查询正在下载的任务并发送消息
"""
@ -491,4 +491,4 @@ class DownloadChain(ChainBase):
f"{StringUtils.str_filesize(torrent.size)} "
f"{round(torrent.progress * 100, 1)}%")
index += 1
self.post_message(title=title, text="\n".join(messages))
self.post_message(title=title, text="\n".join(messages), userid=userid)

View File

@ -1,3 +1,5 @@
from typing import Union
from app.chain import ChainBase
from app.core.config import settings
from app.db.site_oper import SiteOper
@ -18,7 +20,7 @@ class SiteMessageChain(ChainBase):
self._siteoper = SiteOper()
self._cookiehelper = CookieHelper()
def process(self):
def process(self, userid: Union[str, int] = None):
"""
查询所有站点发送消息
"""
@ -40,9 +42,9 @@ class SiteMessageChain(ChainBase):
else:
messages.append(f"{site.id}. {site.name}")
# 发送列表
self.post_message(title=title, text="\n".join(messages))
self.post_message(title=title, text="\n".join(messages), userid=userid)
def disable(self, arg_str):
def disable(self, arg_str, userid: Union[str, int] = None):
"""
禁用站点
"""
@ -54,7 +56,7 @@ class SiteMessageChain(ChainBase):
site_id = int(arg_str)
site = self._siteoper.get(site_id)
if not site:
self.post_message(title=f"站点编号 {site_id} 不存在!")
self.post_message(title=f"站点编号 {site_id} 不存在!", userid=userid)
return
# 禁用站点
self._siteoper.update(site_id, {
@ -63,7 +65,7 @@ class SiteMessageChain(ChainBase):
# 重新发送消息
self.process()
def enable(self, arg_str):
def enable(self, arg_str, userid: Union[str, int] = None):
"""
启用站点
"""
@ -75,7 +77,7 @@ class SiteMessageChain(ChainBase):
site_id = int(arg_str)
site = self._siteoper.get(site_id)
if not site:
self.post_message(title=f"站点编号 {site_id} 不存在!")
self.post_message(title=f"站点编号 {site_id} 不存在!", userid=userid)
return
# 禁用站点
self._siteoper.update(site_id, {
@ -84,30 +86,30 @@ class SiteMessageChain(ChainBase):
# 重新发送消息
self.process()
def get_cookie(self, arg_str: str):
def get_cookie(self, arg_str: str, userid: Union[str, int] = None):
"""
使用用户名密码更新站点Cookie
"""
err_title = "请输入正确的命令格式:/site_cookie [id] [username] [password]" \
"[id]为站点编号,[uername]为站点用户名,[password]为站点密码"
if not arg_str:
self.post_message(title=err_title)
self.post_message(title=err_title, userid=userid)
return
arg_str = arg_str.strip()
args = arg_str.split()
if len(args) != 3:
self.post_message(title=err_title)
self.post_message(title=err_title, userid=userid)
return
site_id = args[0]
if not site_id.isdigit():
self.post_message(title=err_title)
self.post_message(title=err_title, userid=userid)
return
# 站点ID
site_id = int(site_id)
# 站点信息
site_info = self._siteoper.get(site_id)
if not site_info:
self.post_message(title=f"站点编号 {site_id} 不存在!")
self.post_message(title=f"站点编号 {site_id} 不存在!", userid=userid)
return
# 用户名
username = args[1]
@ -125,10 +127,12 @@ class SiteMessageChain(ChainBase):
if not cookie:
logger.error(msg)
self.post_message(title=f"{site_info.name}】 Cookie&UA更新失败",
text=f"错误原因:{msg}")
text=f"错误原因:{msg}",
userid=userid)
return
self._siteoper.update(site_id, {
"cookie": cookie,
"ua": ua
})
self.post_message(title=f"{site_info.name}】 Cookie&UA更新成功")
self.post_message(title=f"{site_info.name}】 Cookie&UA更新成功",
userid=userid)

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
from app.chain import ChainBase
from app.chain.download import DownloadChain
@ -295,13 +295,13 @@ class SubscribeChain(ChainBase):
"lack_episode": len(left_episodes)
})
def list(self):
def list(self, userid: Union[str, int] = None):
"""
查询订阅并发送消息
"""
subscribes = self.subscribehelper.list()
if not subscribes:
self.post_message(title='没有任何订阅!')
self.post_message(title='没有任何订阅!', userid=userid)
return
title = f"共有 {len(subscribes)} 个订阅,回复对应指令操作: " \
f"\n- 删除订阅:/subscribe_delete [id]"
@ -317,9 +317,9 @@ class SubscribeChain(ChainBase):
f"_{subscribe.total_episode - (subscribe.lack_episode or subscribe.total_episode)}"
f"/{subscribe.total_episode}_")
# 发送列表
self.post_message(title=title, text='\n'.join(messages))
self.post_message(title=title, text='\n'.join(messages), userid=userid)
def delete(self, arg_str: str):
def delete(self, arg_str: str, userid: Union[str, int] = None):
"""
删除订阅
"""
@ -331,7 +331,7 @@ class SubscribeChain(ChainBase):
subscribe_id = int(arg_str)
subscribe = self.subscribehelper.get(subscribe_id)
if not subscribe:
self.post_message(title=f"订阅编号 {subscribe_id} 不存在!")
self.post_message(title=f"订阅编号 {subscribe_id} 不存在!", userid=userid)
return
# 删除订阅
self.subscribehelper.delete(subscribe_id)

View File

@ -1,6 +1,6 @@
import re
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Union
from app.chain import ChainBase
from app.core.context import MediaInfo
@ -18,7 +18,7 @@ class TransferChain(ChainBase):
文件转移处理链
"""
def process(self, arg_str: str = None) -> bool:
def process(self, arg_str: str = None, userid: Union[str, int] = None) -> bool:
"""
获取下载器中的种子列表并执行转移
"""
@ -87,7 +87,8 @@ class TransferChain(ChainBase):
if not mediainfo:
logger.warn(f'未识别到媒体信息,标题:{torrent.title}')
self.post_message(title=f"{torrent.title} 未识别到媒体信息,无法入库!\n"
f"回复:```\n/transfer {torrent.hash} [tmdbid]\n``` 手动识别转移。")
f"回复:```\n/transfer {torrent.hash} [tmdbid]\n``` 手动识别转移。",
userid=userid)
continue
else:
mediainfo = arg_mediainfo
@ -101,7 +102,8 @@ class TransferChain(ChainBase):
self.post_message(
title=f"{mediainfo.title_year}{meta.season_episode} 入库失败!",
text=f"原因:{transferinfo.message if transferinfo else '未知'}",
image=mediainfo.get_message_image()
image=mediainfo.get_message_image(),
userid=userid
),
continue
# 转移完成

View File

@ -58,7 +58,8 @@ class UserMessageChain(ChainBase):
self.eventmanager.send_event(
EventType.CommandExcute,
{
"cmd": text
"cmd": text,
"user": userid
}
)
self.post_message(title=f"正在运行,请稍候 ...", userid=userid)

View File

@ -1,6 +1,7 @@
import inspect
import traceback
from threading import Thread, Event
from typing import Any
from typing import Any, Union
from app.chain import ChainBase
from app.chain.cookiecloud import CookieCloudChain
@ -13,6 +14,7 @@ from app.core.event import eventmanager, EventManager
from app.core.plugin import PluginManager
from app.core.event import Event as ManagerEvent
from app.log import logger
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton
from app.utils.types import EventType
@ -173,22 +175,29 @@ class Command(metaclass=Singleton):
"""
return self._commands.get(cmd, {})
def execute(self, cmd: str, data_str: str = "") -> None:
def execute(self, cmd: str, data_str: str = "", userid: Union[str, int] = None) -> None:
"""
执行命令
"""
command = self.get(cmd)
if command:
try:
logger.info(f"开始执行:{command.get('description')} ...")
logger.info(f"用户 {userid} 开始执行:{command.get('description')} ...")
cmd_data = command['data'] if command.get('data') else {}
if cmd_data:
command['func'](**cmd_data)
elif data_str:
command['func'](data_str)
if ObjectUtils.has_arguments(command['func']):
if cmd_data:
# 使用内置参数
command['func'](**cmd_data)
elif data_str:
# 使用用户输入参数
command['func'](data_str, userid)
else:
# 没有用户输入参数
command['func'](userid)
else:
# 没有参数
command['func']()
logger.info(f"{command.get('description')} 执行完成")
logger.info(f"用户 {userid} {command.get('description')} 执行完成")
except Exception as err:
logger.error(f"执行命令 {cmd} 出错:{str(err)}")
traceback.print_exc()
@ -208,9 +217,12 @@ class Command(metaclass=Singleton):
"cmd": "/xxx args"
}
"""
# 命令参数
event_str = event.event_data.get('cmd')
# 消息用户
event_user = event.event_data.get('user')
if event_str:
cmd = event_str.split()[0]
args = " ".join(event_str.split()[1:])
if self.get(cmd):
self.execute(cmd, args)
self.execute(cmd, args, event_user)

View File

@ -1,3 +1,4 @@
import inspect
from typing import Any
@ -9,3 +10,18 @@ class ObjectUtils:
return True
else:
return str(obj).startswith("{") or str(obj).startswith("[")
@staticmethod
def has_arguments(func):
"""
判断函数是否有参数
"""
signature = inspect.signature(func)
parameters = signature.parameters
parameter_names = list(parameters.keys())
# 排除 self 参数
if parameter_names and parameter_names[0] == 'self':
parameter_names = parameter_names[1:]
return len(parameter_names) > 0