add 插件API

This commit is contained in:
jxxghp 2023-06-18 10:22:14 +08:00
parent c0ea8097f8
commit 5dd7878e1b
17 changed files with 146 additions and 83 deletions

View File

@ -1,6 +1,6 @@
from fastapi import APIRouter
from app.api.endpoints import login, user, site, message, webhook, subscribe, media, douban, search
from app.api.endpoints import login, user, site, message, webhook, subscribe, media, douban, search, plugin
api_router = APIRouter()
api_router.include_router(login.router, tags=["login"])
@ -12,3 +12,4 @@ api_router.include_router(subscribe.router, prefix="/subscribe", tags=["subscrib
api_router.include_router(media.router, prefix="/media", tags=["media"])
api_router.include_router(search.router, prefix="/search", tags=["search"])
api_router.include_router(douban.router, prefix="/douban", tags=["douban"])
api_router.include_router(plugin.router, prefix="/plugin", tags=["plugin"])

View File

@ -25,4 +25,4 @@ async def sync_douban(
同步豆瓣想看
"""
background_tasks.add_task(start_douban_chain)
return {"success": True}
return schemas.Response(success=True, message="任务已启动")

View File

@ -31,9 +31,9 @@ async def login_access_token(
elif not user.is_active:
raise HTTPException(status_code=400, detail="用户未启用")
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
return {
"access_token": security.create_access_token(
return schemas.Token(
access_token=security.create_access_token(
user.id, expires_delta=access_token_expires
),
"token_type": "bearer",
}
token_type="bearer",
)

View File

@ -29,7 +29,7 @@ async def user_message(background_tasks: BackgroundTasks, request: Request):
form = await request.form()
args = request.query_params
background_tasks.add_task(start_message_chain, body, form, args)
return {"success": True}
return schemas.Response(success=True)
@router.get("/")

View File

@ -0,0 +1,25 @@
from typing import Any
from fastapi import APIRouter, Depends
from app import schemas
from app.core.plugin import PluginManager
from app.db.models.user import User
from app.db.userauth import get_current_active_user
router = APIRouter()
@router.get("/", response_model=schemas.Response)
@router.post("/")
async def run_plugin_method(plugin_id: str, method: str,
_: User = Depends(get_current_active_user),
*args,
**kwargs) -> Any:
"""
运行插件方法
"""
return PluginManager().run_plugin_method(pid=plugin_id,
method=method,
*args,
**kwargs)

View File

@ -68,8 +68,8 @@ async def cookie_cloud_sync(_: User = Depends(get_current_active_user)) -> Any:
"""
status, error_msg = CookieCloudChain().process()
if not status:
return {"success": False, "message": error_msg}
return {"success": True, "message": error_msg}
schemas.Response(success=True, message=error_msg)
return schemas.Response(success=True, message="同步成功!")
@router.get("/cookie", response_model=schemas.Response)
@ -94,6 +94,6 @@ async def update_cookie(
username=username,
password=password)
if not status:
return {"success": False, "message": msg}
return schemas.Response(success=False, message=msg)
else:
return {"success": True, "message": msg}
return schemas.Response(success=True, message=msg)

View File

@ -44,7 +44,7 @@ async def create_subscribe(
新增订阅
"""
result = SubscribeChain().add(**subscribe_in.dict())
return {"success": result}
return schemas.Response(success=result)
@router.put("/", response_model=schemas.Subscribe)
@ -78,7 +78,7 @@ async def delete_subscribe(
删除订阅信息
"""
Subscribe.delete(db, subscribe_in.id)
return {"success": True}
return schemas.Response(success=True)
@router.post("/seerr", response_model=schemas.Response)
@ -100,12 +100,12 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
)
notification_type = req_json.get("notification_type")
if notification_type not in ["MEDIA_APPROVED", "MEDIA_AUTO_APPROVED"]:
return {"success": False, "message": "不支持的通知类型"}
return schemas.Response(success=False, message="不支持的通知类型")
subject = req_json.get("subject")
media_type = MediaType.MOVIE if req_json.get("media", {}).get("media_type") == "movie" else MediaType.TV
tmdbId = req_json.get("media", {}).get("tmdbId")
if not media_type or not tmdbId or not subject:
return {"success": False, "message": "请求参数不正确"}
return schemas.Response(success=False, message="请求参数不正确")
user_name = req_json.get("request", {}).get("requestedBy_username")
# 添加订阅
if media_type == MediaType.MOVIE:
@ -131,7 +131,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
season=season,
username=user_name)
return {"success": True}
return schemas.Response(success=True)
@router.get("/refresh", response_model=schemas.Response)
@ -141,7 +141,7 @@ async def refresh_subscribes(
刷新所有订阅
"""
SubscribeChain().refresh()
return {"success": True}
return schemas.Response(success=True)
@router.get("/search", response_model=schemas.Response)
@ -151,4 +151,4 @@ async def search_subscribes(
搜索所有订阅
"""
SubscribeChain().search(state='R')
return {"success": True}
return schemas.Response(success=True)

View File

@ -90,7 +90,7 @@ async def delete_user(
detail="用户不存在",
)
user.delete_by_email(db, user_in.email)
return {"success": True}
return schemas.Response(success=True)
@router.get("/{user_id}", response_model=schemas.User)

View File

@ -23,9 +23,9 @@ async def webhook_message(background_tasks: BackgroundTasks,
Webhook响应
"""
if token != settings.API_TOKEN:
return {"success": False, "message": "token认证不通过"}
return schemas.Response(success=False, message="token认证不通过")
body = await request.body()
form = await request.form()
args = request.query_params
background_tasks.add_task(start_webhook_chain, body, form, args)
return {"success": True}
return schemas.Response(success=True)

View File

@ -340,7 +340,7 @@ async def arr_remove_movie(apikey: str, mid: int, db: Session = Depends(get_db))
subscribe = Subscribe.get(db, mid)
if subscribe:
subscribe.delete(db, mid)
return {"success": True}
return schemas.Response(success=True)
else:
raise HTTPException(
status_code=404,
@ -660,7 +660,7 @@ async def arr_remove_series(apikey: str, tid: int, db: Session = Depends(get_db)
subscribe = Subscribe.get(db, tid)
if subscribe:
subscribe.delete(db, tid)
return {"success": True}
return schemas.Response(success=True)
else:
raise HTTPException(
status_code=404,

View File

@ -1,9 +1,9 @@
from types import FunctionType
from typing import Generator, Optional
from app.core.config import settings
from app.helper.module import ModuleHelper
from app.log import logger
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton
@ -68,16 +68,9 @@ class ModuleManager(metaclass=Singleton):
"""
获取模块列表
"""
def check_method(func: FunctionType) -> bool:
"""
检查函数是否已实现
"""
return func.__code__.co_code != b'd\x01S\x00'
if not self._running_modules:
return []
for _, module in self._running_modules.items():
if hasattr(module, method) \
and check_method(getattr(module, method)):
and ObjectUtils.check_method(getattr(module, method)):
yield module

View File

@ -1,9 +1,10 @@
import traceback
from typing import List, Any
from typing import List, Any, Dict
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.module import ModuleHelper
from app.log import logger
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton
@ -24,6 +25,7 @@ class PluginManager(metaclass=Singleton):
self.init_config()
def init_config(self):
# 配置管理
self.systemconfig = SystemConfigOper()
# 停止已有插件
self.stop()
@ -32,21 +34,7 @@ class PluginManager(metaclass=Singleton):
def start(self):
"""
启动
"""
# 加载插件
self.__load_plugins()
def stop(self):
"""
停止
"""
# 停止所有插件
self.__stop_plugins()
def __load_plugins(self):
"""
加载所有插件
启动加载插件
"""
# 扫描插件目录
plugins = ModuleHelper.load(
@ -59,32 +47,24 @@ class PluginManager(metaclass=Singleton):
self._plugins = {}
for plugin in plugins:
plugin_id = plugin.__name__
self._plugins[plugin_id] = plugin
# 生成实例
self._running_plugins[plugin_id] = plugin()
# 初始化配置
self.reload_plugin(plugin_id)
logger.info(f"Plugin Loaded{plugin.__name__}")
def reload_plugin(self, pid: str):
"""
生效插件配置
"""
if not pid:
return
if not self._running_plugins.get(pid):
return
if hasattr(self._running_plugins[pid], "init_plugin"):
try:
self._running_plugins[pid].init_plugin(self.get_plugin_config(pid))
logger.debug(f"生效插件配置:{pid}")
# 存储Class
self._plugins[plugin_id] = plugin
# 生成实例
plugin_obj = plugin()
# 生效插件配置
plugin_obj.init_plugin(self.get_plugin_config(plugin_id))
# 存储运行实例
self._running_plugins[plugin_id] = plugin_obj
logger.info(f"Plugin Loaded{plugin_id}")
except Exception as err:
logger.error(f"加载插件 {pid} 出错:{err} - {traceback.format_exc()}")
logger.error(f"加载插件 {plugin_id} 出错:{err} - {traceback.format_exc()}")
def __stop_plugins(self):
def stop(self):
"""
停止所有插件
停止
"""
# 停止所有插件
for plugin in self._running_plugins.values():
if hasattr(plugin, "stop"):
plugin.stop()
@ -105,7 +85,7 @@ class PluginManager(metaclass=Singleton):
return False
return self.systemconfig.set(self._config_key % pid, conf)
def get_plugin_commands(self) -> List[dict]:
def get_plugin_commands(self) -> List[Dict[str, Any]]:
"""
获取插件命令
[{
@ -117,8 +97,9 @@ class PluginManager(metaclass=Singleton):
"""
ret_commands = []
for _, plugin in self._running_plugins.items():
if hasattr(plugin, "get_command"):
ret_commands.append(plugin.get_command())
if hasattr(plugin, "get_command") \
and ObjectUtils.check_method(plugin.get_command):
ret_commands += plugin.get_command()
return ret_commands
def run_plugin_method(self, pid: str, method: str, *args, **kwargs) -> Any:

View File

@ -1,6 +1,6 @@
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Any
from typing import Any, List, Dict
from app.chain import ChainBase
from app.core.config import settings
@ -27,8 +27,6 @@ class _PluginBase(metaclass=ABCMeta):
- update_config() 更新配置信息
- init_plugin() 生效配置信息
- get_data_path() 获取插件数据保存目录
- get_command() 获取插件命令使用消息机制通过远程控制
"""
# 插件名称
plugin_name: str = ""
@ -48,6 +46,20 @@ class _PluginBase(metaclass=ABCMeta):
"""
pass
@staticmethod
@abstractmethod
def get_command() -> List[Dict[str, Any]]:
"""
获取插件命令
[{
"cmd": "/xx",
"event": EventType.xx,
"desc": "xxxx",
"data": {}
}]
"""
pass
@abstractmethod
def stop_service(self):
"""

View File

@ -2,12 +2,13 @@ import traceback
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing.pool import ThreadPool
from threading import Event
from typing import Any
from typing import Any, List, Dict
from urllib.parse import urljoin
from apscheduler.schedulers.background import BackgroundScheduler
from ruamel.yaml import CommentedMap
from app import schemas
from app.core.event import EventManager, eventmanager
from app.core.config import settings
from app.helper.browser import PlaywrightHelper
@ -18,6 +19,7 @@ from app.log import logger
from app.plugins import _PluginBase
from app.utils.http import RequestUtils
from app.utils.site import SiteUtils
from app.utils.string import StringUtils
from app.utils.timer import TimerUtils
from app.schemas.types import EventType
@ -64,17 +66,17 @@ class AutoSignIn(_PluginBase):
self._scheduler.start()
@staticmethod
def get_command() -> dict:
def get_command() -> List[Dict[str, Any]]:
"""
定义远程控制命令
:return: 命令关键字事件描述附带数据
"""
return {
return [{
"cmd": "/site_signin",
"event": EventType.SiteSignin,
"desc": "站点签到",
"data": {}
}
}]
@eventmanager.register(EventType.SiteSignin)
def sign_in(self, event: Event = None):
@ -110,6 +112,23 @@ class AutoSignIn(_PluginBase):
logger.error("站点模块加载失败:%s" % str(e))
return None
def signin_by_domain(self, url: str) -> schemas.Response:
"""
签到一个站点可由API调用
"""
domain = StringUtils.get_url_domain(url)
site_info = self.sites.get_indexer(domain)
if site_info:
return schemas.Response(
success=True,
message=f"站点【{url}】不存在"
)
else:
return schemas.Response(
success=True,
message=self.signin_site(site_info)
)
def signin_site(self, site_info: CommentedMap) -> str:
"""
签到一个站点

View File

@ -1,12 +1,13 @@
from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool
from threading import Lock
from typing import Optional, Any, List
from typing import Optional, Any, List, Dict
import requests
from apscheduler.schedulers.background import BackgroundScheduler
from ruamel.yaml import CommentedMap
from app import schemas
from app.core.config import settings
from app.core.event import eventmanager
from app.core.event import Event
@ -64,17 +65,17 @@ class SiteStatistic(_PluginBase):
self._scheduler.start()
@staticmethod
def get_command() -> dict:
def get_command() -> List[Dict[str, Any]]:
"""
定义远程控制命令
:return: 命令关键字事件描述附带数据
"""
return {
return [{
"cmd": "/site_statistic",
"event": EventType.SiteStatistic,
"desc": "站点数据统计",
"data": {}
}
}]
def stop_service(self):
pass
@ -181,6 +182,28 @@ class SiteStatistic(_PluginBase):
return site_schema(site_name, url, site_cookie, html_text, session=session, ua=ua, proxy=proxy)
return None
def refresh_by_domain(self, domain: str) -> schemas.Response:
"""
刷新一个站点数据可由API调用
"""
site_info = self.sites.get_indexer(domain)
if site_info:
site_data = self.__refresh_site_data(site_info)
if site_data:
return schemas.Response(
success=True,
message=f"站点 {domain} 刷新成功",
data=site_data.to_dict()
)
return schemas.Response(
success=False,
message=f"站点 {domain} 刷新数据失败,未获取到数据"
)
return schemas.Response(
success=False,
message=f"站点 {domain} 不存在"
)
def __refresh_site_data(self, site_info: CommentedMap) -> Optional[ISiteUserInfo]:
"""
更新单个site 数据信息

View File

@ -6,3 +6,4 @@ from pydantic import BaseModel
class Response(BaseModel):
success: bool
message: Optional[str] = None
data: Optional[dict] = {}

View File

@ -1,4 +1,5 @@
import inspect
from types import FunctionType
from typing import Any, Callable
@ -25,3 +26,10 @@ class ObjectUtils:
parameter_names = parameter_names[1:]
return len(parameter_names)
@staticmethod
def check_method(func: FunctionType) -> bool:
"""
检查函数是否已实现
"""
return func.__code__.co_code != b'd\x01S\x00'