From a188bff80243057c306663b57028435420e66a31 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Sun, 6 Aug 2023 19:31:41 +0800 Subject: [PATCH] =?UTF-8?q?feat=20chatgpt=E6=8F=92=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/chain/message.py | 73 +++++++++++++++++++++------------ app/plugins/chatgpt/__init__.py | 14 ++++++- 2 files changed, 60 insertions(+), 27 deletions(-) diff --git a/app/chain/message.py b/app/chain/message.py index 34e70986..949519f8 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -283,36 +283,57 @@ class MessageChain(ChainBase): # 订阅 content = re.sub(r"订阅[::\s]*", "", text) action = "Subscribe" + elif text.startswith("#") \ + or re.search(r"^请[问帮你]", text) \ + or re.search(r"[??]$", text) \ + or StringUtils.count_words(text) > 10 \ + or text.find("继续") != -1: + # 聊天 + content = text + action = "chat" else: # 搜索 content = re.sub(r"(搜索|下载)[::\s]*", "", text) action = "Search" - # 搜索 - meta, medias = self.medtachain.search(content) - # 识别 - if not meta.name: - self.post_message(Notification( - channel=channel, title="无法识别输入内容!", userid=userid)) - return - # 开始搜索 - if not medias: - self.post_message(Notification( - channel=channel, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid)) - return - logger.info(f"搜索到 {len(medias)} 条相关媒体信息") - # 记录当前状态 - _current_meta = meta - user_cache[userid] = { - 'type': action, - 'items': medias - } - _current_page = 0 - _current_media = None - # 发送媒体列表 - self.__post_medias_message(channel=channel, - title=meta.name, - items=medias[:self._page_size], - userid=userid, total=len(medias)) + + if action in ["Subscribe", "Search"]: + # 搜索 + meta, medias = self.medtachain.search(content) + # 识别 + if not meta.name: + self.post_message(Notification( + channel=channel, title="无法识别输入内容!", userid=userid)) + return + # 开始搜索 + if not medias: + self.post_message(Notification( + channel=channel, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid)) + return + logger.info(f"搜索到 {len(medias)} 条相关媒体信息") + # 记录当前状态 + _current_meta = meta + user_cache[userid] = { + 'type': action, + 'items': medias + } + _current_page = 0 + _current_media = None + # 发送媒体列表 + self.__post_medias_message(channel=channel, + title=meta.name, + items=medias[:self._page_size], + userid=userid, total=len(medias)) + else: + # 广播事件 + self.eventmanager.send_event( + EventType.UserMessage, + { + "text": content, + "user": userid, + "channel": channel + } + ) + # 保存缓存 self.save_cache(user_cache, self._cache_file) diff --git a/app/plugins/chatgpt/__init__.py b/app/plugins/chatgpt/__init__.py index 82e878ba..c0e06958 100644 --- a/app/plugins/chatgpt/__init__.py +++ b/app/plugins/chatgpt/__init__.py @@ -2,6 +2,7 @@ from typing import Any, List, Dict, Tuple from app.core.event import eventmanager from app.plugins import _PluginBase +from app.plugins.chatgpt.openai import OpenAi from app.schemas.types import EventType @@ -28,6 +29,7 @@ class ChatGPT(_PluginBase): auth_level = 1 # 私有属性 + openai = None _enabled = False _openai_url = None _openai_key = None @@ -37,6 +39,7 @@ class ChatGPT(_PluginBase): self._enabled = config.get("enabled") self._openai_url = config.get("openai_url") self._openai_key = config.get("openai_key") + self.openai = OpenAi(api_key=self._openai_key, api_url=self._openai_url) def get_state(self) -> bool: return self._enabled @@ -131,7 +134,16 @@ class ChatGPT(_PluginBase): """ 监听用户消息,获取ChatGPT回复 """ - pass + if not self.openai: + return + text = event.event_data.get("text") + userid = event.event_data.get("userid") + channel = event.event_data.get("channel") + if not text: + return + response = self.openai.get_response(text=text, userid=userid) + if response: + self.post_message(channel=channel, title=text, userid=userid) def stop_service(self): """