add 插件API

This commit is contained in:
jxxghp
2023-06-18 10:22:14 +08:00
parent c0ea8097f8
commit 5dd7878e1b
17 changed files with 146 additions and 83 deletions

View File

@ -1,6 +1,6 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api.endpoints import login, user, site, message, webhook, subscribe, media, douban, search from app.api.endpoints import login, user, site, message, webhook, subscribe, media, douban, search, plugin
api_router = APIRouter() api_router = APIRouter()
api_router.include_router(login.router, tags=["login"]) api_router.include_router(login.router, tags=["login"])
@ -12,3 +12,4 @@ api_router.include_router(subscribe.router, prefix="/subscribe", tags=["subscrib
api_router.include_router(media.router, prefix="/media", tags=["media"]) api_router.include_router(media.router, prefix="/media", tags=["media"])
api_router.include_router(search.router, prefix="/search", tags=["search"]) api_router.include_router(search.router, prefix="/search", tags=["search"])
api_router.include_router(douban.router, prefix="/douban", tags=["douban"]) api_router.include_router(douban.router, prefix="/douban", tags=["douban"])
api_router.include_router(plugin.router, prefix="/plugin", tags=["plugin"])

View File

@ -25,4 +25,4 @@ async def sync_douban(
同步豆瓣想看 同步豆瓣想看
""" """
background_tasks.add_task(start_douban_chain) background_tasks.add_task(start_douban_chain)
return {"success": True} return schemas.Response(success=True, message="任务已启动")

View File

@ -31,9 +31,9 @@ async def login_access_token(
elif not user.is_active: elif not user.is_active:
raise HTTPException(status_code=400, detail="用户未启用") raise HTTPException(status_code=400, detail="用户未启用")
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
return { return schemas.Token(
"access_token": security.create_access_token( access_token=security.create_access_token(
user.id, expires_delta=access_token_expires user.id, expires_delta=access_token_expires
), ),
"token_type": "bearer", token_type="bearer",
} )

View File

@ -29,7 +29,7 @@ async def user_message(background_tasks: BackgroundTasks, request: Request):
form = await request.form() form = await request.form()
args = request.query_params args = request.query_params
background_tasks.add_task(start_message_chain, body, form, args) background_tasks.add_task(start_message_chain, body, form, args)
return {"success": True} return schemas.Response(success=True)
@router.get("/") @router.get("/")

View File

@ -0,0 +1,25 @@
from typing import Any
from fastapi import APIRouter, Depends
from app import schemas
from app.core.plugin import PluginManager
from app.db.models.user import User
from app.db.userauth import get_current_active_user
router = APIRouter()
@router.get("/", response_model=schemas.Response)
@router.post("/")
async def run_plugin_method(plugin_id: str, method: str,
_: User = Depends(get_current_active_user),
*args,
**kwargs) -> Any:
"""
运行插件方法
"""
return PluginManager().run_plugin_method(pid=plugin_id,
method=method,
*args,
**kwargs)

View File

@ -68,8 +68,8 @@ async def cookie_cloud_sync(_: User = Depends(get_current_active_user)) -> Any:
""" """
status, error_msg = CookieCloudChain().process() status, error_msg = CookieCloudChain().process()
if not status: if not status:
return {"success": False, "message": error_msg} schemas.Response(success=True, message=error_msg)
return {"success": True, "message": error_msg} return schemas.Response(success=True, message="同步成功!")
@router.get("/cookie", response_model=schemas.Response) @router.get("/cookie", response_model=schemas.Response)
@ -94,6 +94,6 @@ async def update_cookie(
username=username, username=username,
password=password) password=password)
if not status: if not status:
return {"success": False, "message": msg} return schemas.Response(success=False, message=msg)
else: else:
return {"success": True, "message": msg} return schemas.Response(success=True, message=msg)

View File

@ -44,7 +44,7 @@ async def create_subscribe(
新增订阅 新增订阅
""" """
result = SubscribeChain().add(**subscribe_in.dict()) result = SubscribeChain().add(**subscribe_in.dict())
return {"success": result} return schemas.Response(success=result)
@router.put("/", response_model=schemas.Subscribe) @router.put("/", response_model=schemas.Subscribe)
@ -78,7 +78,7 @@ async def delete_subscribe(
删除订阅信息 删除订阅信息
""" """
Subscribe.delete(db, subscribe_in.id) Subscribe.delete(db, subscribe_in.id)
return {"success": True} return schemas.Response(success=True)
@router.post("/seerr", response_model=schemas.Response) @router.post("/seerr", response_model=schemas.Response)
@ -100,12 +100,12 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
) )
notification_type = req_json.get("notification_type") notification_type = req_json.get("notification_type")
if notification_type not in ["MEDIA_APPROVED", "MEDIA_AUTO_APPROVED"]: if notification_type not in ["MEDIA_APPROVED", "MEDIA_AUTO_APPROVED"]:
return {"success": False, "message": "不支持的通知类型"} return schemas.Response(success=False, message="不支持的通知类型")
subject = req_json.get("subject") subject = req_json.get("subject")
media_type = MediaType.MOVIE if req_json.get("media", {}).get("media_type") == "movie" else MediaType.TV media_type = MediaType.MOVIE if req_json.get("media", {}).get("media_type") == "movie" else MediaType.TV
tmdbId = req_json.get("media", {}).get("tmdbId") tmdbId = req_json.get("media", {}).get("tmdbId")
if not media_type or not tmdbId or not subject: if not media_type or not tmdbId or not subject:
return {"success": False, "message": "请求参数不正确"} return schemas.Response(success=False, message="请求参数不正确")
user_name = req_json.get("request", {}).get("requestedBy_username") user_name = req_json.get("request", {}).get("requestedBy_username")
# 添加订阅 # 添加订阅
if media_type == MediaType.MOVIE: if media_type == MediaType.MOVIE:
@ -131,7 +131,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
season=season, season=season,
username=user_name) username=user_name)
return {"success": True} return schemas.Response(success=True)
@router.get("/refresh", response_model=schemas.Response) @router.get("/refresh", response_model=schemas.Response)
@ -141,7 +141,7 @@ async def refresh_subscribes(
刷新所有订阅 刷新所有订阅
""" """
SubscribeChain().refresh() SubscribeChain().refresh()
return {"success": True} return schemas.Response(success=True)
@router.get("/search", response_model=schemas.Response) @router.get("/search", response_model=schemas.Response)
@ -151,4 +151,4 @@ async def search_subscribes(
搜索所有订阅 搜索所有订阅
""" """
SubscribeChain().search(state='R') SubscribeChain().search(state='R')
return {"success": True} return schemas.Response(success=True)

View File

@ -90,7 +90,7 @@ async def delete_user(
detail="用户不存在", detail="用户不存在",
) )
user.delete_by_email(db, user_in.email) user.delete_by_email(db, user_in.email)
return {"success": True} return schemas.Response(success=True)
@router.get("/{user_id}", response_model=schemas.User) @router.get("/{user_id}", response_model=schemas.User)

View File

@ -23,9 +23,9 @@ async def webhook_message(background_tasks: BackgroundTasks,
Webhook响应 Webhook响应
""" """
if token != settings.API_TOKEN: if token != settings.API_TOKEN:
return {"success": False, "message": "token认证不通过"} return schemas.Response(success=False, message="token认证不通过")
body = await request.body() body = await request.body()
form = await request.form() form = await request.form()
args = request.query_params args = request.query_params
background_tasks.add_task(start_webhook_chain, body, form, args) background_tasks.add_task(start_webhook_chain, body, form, args)
return {"success": True} return schemas.Response(success=True)

View File

@ -340,7 +340,7 @@ async def arr_remove_movie(apikey: str, mid: int, db: Session = Depends(get_db))
subscribe = Subscribe.get(db, mid) subscribe = Subscribe.get(db, mid)
if subscribe: if subscribe:
subscribe.delete(db, mid) subscribe.delete(db, mid)
return {"success": True} return schemas.Response(success=True)
else: else:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
@ -660,7 +660,7 @@ async def arr_remove_series(apikey: str, tid: int, db: Session = Depends(get_db)
subscribe = Subscribe.get(db, tid) subscribe = Subscribe.get(db, tid)
if subscribe: if subscribe:
subscribe.delete(db, tid) subscribe.delete(db, tid)
return {"success": True} return schemas.Response(success=True)
else: else:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,

View File

@ -1,9 +1,9 @@
from types import FunctionType
from typing import Generator, Optional from typing import Generator, Optional
from app.core.config import settings from app.core.config import settings
from app.helper.module import ModuleHelper from app.helper.module import ModuleHelper
from app.log import logger from app.log import logger
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton from app.utils.singleton import Singleton
@ -68,16 +68,9 @@ class ModuleManager(metaclass=Singleton):
""" """
获取模块列表 获取模块列表
""" """
def check_method(func: FunctionType) -> bool:
"""
检查函数是否已实现
"""
return func.__code__.co_code != b'd\x01S\x00'
if not self._running_modules: if not self._running_modules:
return [] return []
for _, module in self._running_modules.items(): for _, module in self._running_modules.items():
if hasattr(module, method) \ if hasattr(module, method) \
and check_method(getattr(module, method)): and ObjectUtils.check_method(getattr(module, method)):
yield module yield module

View File

@ -1,9 +1,10 @@
import traceback import traceback
from typing import List, Any from typing import List, Any, Dict
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.helper.module import ModuleHelper from app.helper.module import ModuleHelper
from app.log import logger from app.log import logger
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton from app.utils.singleton import Singleton
@ -24,6 +25,7 @@ class PluginManager(metaclass=Singleton):
self.init_config() self.init_config()
def init_config(self): def init_config(self):
# 配置管理
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper()
# 停止已有插件 # 停止已有插件
self.stop() self.stop()
@ -32,21 +34,7 @@ class PluginManager(metaclass=Singleton):
def start(self): def start(self):
""" """
启动 启动加载插件
"""
# 加载插件
self.__load_plugins()
def stop(self):
"""
停止
"""
# 停止所有插件
self.__stop_plugins()
def __load_plugins(self):
"""
加载所有插件
""" """
# 扫描插件目录 # 扫描插件目录
plugins = ModuleHelper.load( plugins = ModuleHelper.load(
@ -59,32 +47,24 @@ class PluginManager(metaclass=Singleton):
self._plugins = {} self._plugins = {}
for plugin in plugins: for plugin in plugins:
plugin_id = plugin.__name__ plugin_id = plugin.__name__
self._plugins[plugin_id] = plugin
# 生成实例
self._running_plugins[plugin_id] = plugin()
# 初始化配置
self.reload_plugin(plugin_id)
logger.info(f"Plugin Loaded{plugin.__name__}")
def reload_plugin(self, pid: str):
"""
生效插件配置
"""
if not pid:
return
if not self._running_plugins.get(pid):
return
if hasattr(self._running_plugins[pid], "init_plugin"):
try: try:
self._running_plugins[pid].init_plugin(self.get_plugin_config(pid)) # 存储Class
logger.debug(f"生效插件配置:{pid}") self._plugins[plugin_id] = plugin
# 生成实例
plugin_obj = plugin()
# 生效插件配置
plugin_obj.init_plugin(self.get_plugin_config(plugin_id))
# 存储运行实例
self._running_plugins[plugin_id] = plugin_obj
logger.info(f"Plugin Loaded{plugin_id}")
except Exception as err: except Exception as err:
logger.error(f"加载插件 {pid} 出错:{err} - {traceback.format_exc()}") logger.error(f"加载插件 {plugin_id} 出错:{err} - {traceback.format_exc()}")
def __stop_plugins(self): def stop(self):
""" """
停止所有插件 停止
""" """
# 停止所有插件
for plugin in self._running_plugins.values(): for plugin in self._running_plugins.values():
if hasattr(plugin, "stop"): if hasattr(plugin, "stop"):
plugin.stop() plugin.stop()
@ -105,7 +85,7 @@ class PluginManager(metaclass=Singleton):
return False return False
return self.systemconfig.set(self._config_key % pid, conf) return self.systemconfig.set(self._config_key % pid, conf)
def get_plugin_commands(self) -> List[dict]: def get_plugin_commands(self) -> List[Dict[str, Any]]:
""" """
获取插件命令 获取插件命令
[{ [{
@ -117,8 +97,9 @@ class PluginManager(metaclass=Singleton):
""" """
ret_commands = [] ret_commands = []
for _, plugin in self._running_plugins.items(): for _, plugin in self._running_plugins.items():
if hasattr(plugin, "get_command"): if hasattr(plugin, "get_command") \
ret_commands.append(plugin.get_command()) and ObjectUtils.check_method(plugin.get_command):
ret_commands += plugin.get_command()
return ret_commands return ret_commands
def run_plugin_method(self, pid: str, method: str, *args, **kwargs) -> Any: def run_plugin_method(self, pid: str, method: str, *args, **kwargs) -> Any:

View File

@ -1,6 +1,6 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, List, Dict
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
@ -27,8 +27,6 @@ class _PluginBase(metaclass=ABCMeta):
- update_config() 更新配置信息 - update_config() 更新配置信息
- init_plugin() 生效配置信息 - init_plugin() 生效配置信息
- get_data_path() 获取插件数据保存目录 - get_data_path() 获取插件数据保存目录
- get_command() 获取插件命令,使用消息机制通过远程控制
""" """
# 插件名称 # 插件名称
plugin_name: str = "" plugin_name: str = ""
@ -48,6 +46,20 @@ class _PluginBase(metaclass=ABCMeta):
""" """
pass pass
@staticmethod
@abstractmethod
def get_command() -> List[Dict[str, Any]]:
"""
获取插件命令
[{
"cmd": "/xx",
"event": EventType.xx,
"desc": "xxxx",
"data": {}
}]
"""
pass
@abstractmethod @abstractmethod
def stop_service(self): def stop_service(self):
""" """

View File

@ -2,12 +2,13 @@ import traceback
from multiprocessing.dummy import Pool as ThreadPool from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from threading import Event from threading import Event
from typing import Any from typing import Any, List, Dict
from urllib.parse import urljoin from urllib.parse import urljoin
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
from ruamel.yaml import CommentedMap from ruamel.yaml import CommentedMap
from app import schemas
from app.core.event import EventManager, eventmanager from app.core.event import EventManager, eventmanager
from app.core.config import settings from app.core.config import settings
from app.helper.browser import PlaywrightHelper from app.helper.browser import PlaywrightHelper
@ -18,6 +19,7 @@ from app.log import logger
from app.plugins import _PluginBase from app.plugins import _PluginBase
from app.utils.http import RequestUtils from app.utils.http import RequestUtils
from app.utils.site import SiteUtils from app.utils.site import SiteUtils
from app.utils.string import StringUtils
from app.utils.timer import TimerUtils from app.utils.timer import TimerUtils
from app.schemas.types import EventType from app.schemas.types import EventType
@ -64,17 +66,17 @@ class AutoSignIn(_PluginBase):
self._scheduler.start() self._scheduler.start()
@staticmethod @staticmethod
def get_command() -> dict: def get_command() -> List[Dict[str, Any]]:
""" """
定义远程控制命令 定义远程控制命令
:return: 命令关键字、事件、描述、附带数据 :return: 命令关键字、事件、描述、附带数据
""" """
return { return [{
"cmd": "/site_signin", "cmd": "/site_signin",
"event": EventType.SiteSignin, "event": EventType.SiteSignin,
"desc": "站点签到", "desc": "站点签到",
"data": {} "data": {}
} }]
@eventmanager.register(EventType.SiteSignin) @eventmanager.register(EventType.SiteSignin)
def sign_in(self, event: Event = None): def sign_in(self, event: Event = None):
@ -110,6 +112,23 @@ class AutoSignIn(_PluginBase):
logger.error("站点模块加载失败:%s" % str(e)) logger.error("站点模块加载失败:%s" % str(e))
return None return None
def signin_by_domain(self, url: str) -> schemas.Response:
"""
签到一个站点可由API调用
"""
domain = StringUtils.get_url_domain(url)
site_info = self.sites.get_indexer(domain)
if site_info:
return schemas.Response(
success=True,
message=f"站点【{url}】不存在"
)
else:
return schemas.Response(
success=True,
message=self.signin_site(site_info)
)
def signin_site(self, site_info: CommentedMap) -> str: def signin_site(self, site_info: CommentedMap) -> str:
""" """
签到一个站点 签到一个站点

View File

@ -1,12 +1,13 @@
from datetime import datetime from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool from multiprocessing.dummy import Pool as ThreadPool
from threading import Lock from threading import Lock
from typing import Optional, Any, List from typing import Optional, Any, List, Dict
import requests import requests
from apscheduler.schedulers.background import BackgroundScheduler from apscheduler.schedulers.background import BackgroundScheduler
from ruamel.yaml import CommentedMap from ruamel.yaml import CommentedMap
from app import schemas
from app.core.config import settings from app.core.config import settings
from app.core.event import eventmanager from app.core.event import eventmanager
from app.core.event import Event from app.core.event import Event
@ -64,17 +65,17 @@ class SiteStatistic(_PluginBase):
self._scheduler.start() self._scheduler.start()
@staticmethod @staticmethod
def get_command() -> dict: def get_command() -> List[Dict[str, Any]]:
""" """
定义远程控制命令 定义远程控制命令
:return: 命令关键字、事件、描述、附带数据 :return: 命令关键字、事件、描述、附带数据
""" """
return { return [{
"cmd": "/site_statistic", "cmd": "/site_statistic",
"event": EventType.SiteStatistic, "event": EventType.SiteStatistic,
"desc": "站点数据统计", "desc": "站点数据统计",
"data": {} "data": {}
} }]
def stop_service(self): def stop_service(self):
pass pass
@ -181,6 +182,28 @@ class SiteStatistic(_PluginBase):
return site_schema(site_name, url, site_cookie, html_text, session=session, ua=ua, proxy=proxy) return site_schema(site_name, url, site_cookie, html_text, session=session, ua=ua, proxy=proxy)
return None return None
def refresh_by_domain(self, domain: str) -> schemas.Response:
"""
刷新一个站点数据可由API调用
"""
site_info = self.sites.get_indexer(domain)
if site_info:
site_data = self.__refresh_site_data(site_info)
if site_data:
return schemas.Response(
success=True,
message=f"站点 {domain} 刷新成功",
data=site_data.to_dict()
)
return schemas.Response(
success=False,
message=f"站点 {domain} 刷新数据失败,未获取到数据"
)
return schemas.Response(
success=False,
message=f"站点 {domain} 不存在"
)
def __refresh_site_data(self, site_info: CommentedMap) -> Optional[ISiteUserInfo]: def __refresh_site_data(self, site_info: CommentedMap) -> Optional[ISiteUserInfo]:
""" """
更新单个site 数据信息 更新单个site 数据信息

View File

@ -6,3 +6,4 @@ from pydantic import BaseModel
class Response(BaseModel): class Response(BaseModel):
success: bool success: bool
message: Optional[str] = None message: Optional[str] = None
data: Optional[dict] = {}

View File

@ -1,4 +1,5 @@
import inspect import inspect
from types import FunctionType
from typing import Any, Callable from typing import Any, Callable
@ -25,3 +26,10 @@ class ObjectUtils:
parameter_names = parameter_names[1:] parameter_names = parameter_names[1:]
return len(parameter_names) return len(parameter_names)
@staticmethod
def check_method(func: FunctionType) -> bool:
"""
检查函数是否已实现
"""
return func.__code__.co_code != b'd\x01S\x00'