feat chatgpt插件

This commit is contained in:
jxxghp
2023-08-06 19:31:41 +08:00
parent 6a4b5e6549
commit a188bff802
2 changed files with 60 additions and 27 deletions

View File

@ -283,36 +283,57 @@ class MessageChain(ChainBase):
# 订阅 # 订阅
content = re.sub(r"订阅[:\s]*", "", text) content = re.sub(r"订阅[:\s]*", "", text)
action = "Subscribe" action = "Subscribe"
elif text.startswith("#") \
or re.search(r"^请[问帮你]", text) \
or re.search(r"[?]$", text) \
or StringUtils.count_words(text) > 10 \
or text.find("继续") != -1:
# 聊天
content = text
action = "chat"
else: else:
# 搜索 # 搜索
content = re.sub(r"(搜索|下载)[:\s]*", "", text) content = re.sub(r"(搜索|下载)[:\s]*", "", text)
action = "Search" action = "Search"
# 搜索
meta, medias = self.medtachain.search(content) if action in ["Subscribe", "Search"]:
# 识别 # 搜索
if not meta.name: meta, medias = self.medtachain.search(content)
self.post_message(Notification( # 识别
channel=channel, title="无法识别输入内容!", userid=userid)) if not meta.name:
return self.post_message(Notification(
# 开始搜索 channel=channel, title="无法识别输入内容!", userid=userid))
if not medias: return
self.post_message(Notification( # 开始搜索
channel=channel, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid)) if not medias:
return self.post_message(Notification(
logger.info(f"搜索到 {len(medias)} 条相关媒体信息") channel=channel, title=f"{meta.name} 没有找到对应的媒体信息", userid=userid))
# 记录当前状态 return
_current_meta = meta logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
user_cache[userid] = { # 记录当前状态
'type': action, _current_meta = meta
'items': medias user_cache[userid] = {
} 'type': action,
_current_page = 0 'items': medias
_current_media = None }
# 发送媒体列表 _current_page = 0
self.__post_medias_message(channel=channel, _current_media = None
title=meta.name, # 发送媒体列表
items=medias[:self._page_size], self.__post_medias_message(channel=channel,
userid=userid, total=len(medias)) title=meta.name,
items=medias[:self._page_size],
userid=userid, total=len(medias))
else:
# 广播事件
self.eventmanager.send_event(
EventType.UserMessage,
{
"text": content,
"user": userid,
"channel": channel
}
)
# 保存缓存 # 保存缓存
self.save_cache(user_cache, self._cache_file) self.save_cache(user_cache, self._cache_file)

View File

@ -2,6 +2,7 @@ from typing import Any, List, Dict, Tuple
from app.core.event import eventmanager from app.core.event import eventmanager
from app.plugins import _PluginBase from app.plugins import _PluginBase
from app.plugins.chatgpt.openai import OpenAi
from app.schemas.types import EventType from app.schemas.types import EventType
@ -28,6 +29,7 @@ class ChatGPT(_PluginBase):
auth_level = 1 auth_level = 1
# 私有属性 # 私有属性
openai = None
_enabled = False _enabled = False
_openai_url = None _openai_url = None
_openai_key = None _openai_key = None
@ -37,6 +39,7 @@ class ChatGPT(_PluginBase):
self._enabled = config.get("enabled") self._enabled = config.get("enabled")
self._openai_url = config.get("openai_url") self._openai_url = config.get("openai_url")
self._openai_key = config.get("openai_key") self._openai_key = config.get("openai_key")
self.openai = OpenAi(api_key=self._openai_key, api_url=self._openai_url)
def get_state(self) -> bool: def get_state(self) -> bool:
return self._enabled return self._enabled
@ -131,7 +134,16 @@ class ChatGPT(_PluginBase):
""" """
监听用户消息获取ChatGPT回复 监听用户消息获取ChatGPT回复
""" """
pass if not self.openai:
return
text = event.event_data.get("text")
userid = event.event_data.get("userid")
channel = event.event_data.get("channel")
if not text:
return
response = self.openai.get_response(text=text, userid=userid)
if response:
self.post_message(channel=channel, title=text, userid=userid)
def stop_service(self): def stop_service(self):
""" """