add 插件API
This commit is contained in:
parent
c0ea8097f8
commit
5dd7878e1b
@ -1,6 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, site, message, webhook, subscribe, media, douban, search
|
||||
from app.api.endpoints import login, user, site, message, webhook, subscribe, media, douban, search, plugin
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, tags=["login"])
|
||||
@ -12,3 +12,4 @@ api_router.include_router(subscribe.router, prefix="/subscribe", tags=["subscrib
|
||||
api_router.include_router(media.router, prefix="/media", tags=["media"])
|
||||
api_router.include_router(search.router, prefix="/search", tags=["search"])
|
||||
api_router.include_router(douban.router, prefix="/douban", tags=["douban"])
|
||||
api_router.include_router(plugin.router, prefix="/plugin", tags=["plugin"])
|
||||
|
@ -25,4 +25,4 @@ async def sync_douban(
|
||||
同步豆瓣想看
|
||||
"""
|
||||
background_tasks.add_task(start_douban_chain)
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True, message="任务已启动")
|
||||
|
@ -31,9 +31,9 @@ async def login_access_token(
|
||||
elif not user.is_active:
|
||||
raise HTTPException(status_code=400, detail="用户未启用")
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
return {
|
||||
"access_token": security.create_access_token(
|
||||
return schemas.Token(
|
||||
access_token=security.create_access_token(
|
||||
user.id, expires_delta=access_token_expires
|
||||
),
|
||||
"token_type": "bearer",
|
||||
}
|
||||
token_type="bearer",
|
||||
)
|
||||
|
@ -29,7 +29,7 @@ async def user_message(background_tasks: BackgroundTasks, request: Request):
|
||||
form = await request.form()
|
||||
args = request.query_params
|
||||
background_tasks.add_task(start_message_chain, body, form, args)
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
|
25
app/api/endpoints/plugin.py
Normal file
25
app/api/endpoints/plugin.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app import schemas
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.models.user import User
|
||||
from app.db.userauth import get_current_active_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=schemas.Response)
|
||||
@router.post("/")
|
||||
async def run_plugin_method(plugin_id: str, method: str,
|
||||
_: User = Depends(get_current_active_user),
|
||||
*args,
|
||||
**kwargs) -> Any:
|
||||
"""
|
||||
运行插件方法
|
||||
"""
|
||||
return PluginManager().run_plugin_method(pid=plugin_id,
|
||||
method=method,
|
||||
*args,
|
||||
**kwargs)
|
@ -68,8 +68,8 @@ async def cookie_cloud_sync(_: User = Depends(get_current_active_user)) -> Any:
|
||||
"""
|
||||
status, error_msg = CookieCloudChain().process()
|
||||
if not status:
|
||||
return {"success": False, "message": error_msg}
|
||||
return {"success": True, "message": error_msg}
|
||||
schemas.Response(success=True, message=error_msg)
|
||||
return schemas.Response(success=True, message="同步成功!")
|
||||
|
||||
|
||||
@router.get("/cookie", response_model=schemas.Response)
|
||||
@ -94,6 +94,6 @@ async def update_cookie(
|
||||
username=username,
|
||||
password=password)
|
||||
if not status:
|
||||
return {"success": False, "message": msg}
|
||||
return schemas.Response(success=False, message=msg)
|
||||
else:
|
||||
return {"success": True, "message": msg}
|
||||
return schemas.Response(success=True, message=msg)
|
||||
|
@ -44,7 +44,7 @@ async def create_subscribe(
|
||||
新增订阅
|
||||
"""
|
||||
result = SubscribeChain().add(**subscribe_in.dict())
|
||||
return {"success": result}
|
||||
return schemas.Response(success=result)
|
||||
|
||||
|
||||
@router.put("/", response_model=schemas.Subscribe)
|
||||
@ -78,7 +78,7 @@ async def delete_subscribe(
|
||||
删除订阅信息
|
||||
"""
|
||||
Subscribe.delete(db, subscribe_in.id)
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post("/seerr", response_model=schemas.Response)
|
||||
@ -100,12 +100,12 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
|
||||
)
|
||||
notification_type = req_json.get("notification_type")
|
||||
if notification_type not in ["MEDIA_APPROVED", "MEDIA_AUTO_APPROVED"]:
|
||||
return {"success": False, "message": "不支持的通知类型"}
|
||||
return schemas.Response(success=False, message="不支持的通知类型")
|
||||
subject = req_json.get("subject")
|
||||
media_type = MediaType.MOVIE if req_json.get("media", {}).get("media_type") == "movie" else MediaType.TV
|
||||
tmdbId = req_json.get("media", {}).get("tmdbId")
|
||||
if not media_type or not tmdbId or not subject:
|
||||
return {"success": False, "message": "请求参数不正确"}
|
||||
return schemas.Response(success=False, message="请求参数不正确")
|
||||
user_name = req_json.get("request", {}).get("requestedBy_username")
|
||||
# 添加订阅
|
||||
if media_type == MediaType.MOVIE:
|
||||
@ -131,7 +131,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
|
||||
season=season,
|
||||
username=user_name)
|
||||
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/refresh", response_model=schemas.Response)
|
||||
@ -141,7 +141,7 @@ async def refresh_subscribes(
|
||||
刷新所有订阅
|
||||
"""
|
||||
SubscribeChain().refresh()
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/search", response_model=schemas.Response)
|
||||
@ -151,4 +151,4 @@ async def search_subscribes(
|
||||
搜索所有订阅
|
||||
"""
|
||||
SubscribeChain().search(state='R')
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True)
|
||||
|
@ -90,7 +90,7 @@ async def delete_user(
|
||||
detail="用户不存在",
|
||||
)
|
||||
user.delete_by_email(db, user_in.email)
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=schemas.User)
|
||||
|
@ -23,9 +23,9 @@ async def webhook_message(background_tasks: BackgroundTasks,
|
||||
Webhook响应
|
||||
"""
|
||||
if token != settings.API_TOKEN:
|
||||
return {"success": False, "message": "token认证不通过"}
|
||||
return schemas.Response(success=False, message="token认证不通过")
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
args = request.query_params
|
||||
background_tasks.add_task(start_webhook_chain, body, form, args)
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True)
|
||||
|
@ -340,7 +340,7 @@ async def arr_remove_movie(apikey: str, mid: int, db: Session = Depends(get_db))
|
||||
subscribe = Subscribe.get(db, mid)
|
||||
if subscribe:
|
||||
subscribe.delete(db, mid)
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
@ -660,7 +660,7 @@ async def arr_remove_series(apikey: str, tid: int, db: Session = Depends(get_db)
|
||||
subscribe = Subscribe.get(db, tid)
|
||||
if subscribe:
|
||||
subscribe.delete(db, tid)
|
||||
return {"success": True}
|
||||
return schemas.Response(success=True)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
|
@ -1,9 +1,9 @@
|
||||
from types import FunctionType
|
||||
from typing import Generator, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.module import ModuleHelper
|
||||
from app.log import logger
|
||||
from app.utils.object import ObjectUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
@ -68,16 +68,9 @@ class ModuleManager(metaclass=Singleton):
|
||||
"""
|
||||
获取模块列表
|
||||
"""
|
||||
|
||||
def check_method(func: FunctionType) -> bool:
|
||||
"""
|
||||
检查函数是否已实现
|
||||
"""
|
||||
return func.__code__.co_code != b'd\x01S\x00'
|
||||
|
||||
if not self._running_modules:
|
||||
return []
|
||||
for _, module in self._running_modules.items():
|
||||
if hasattr(module, method) \
|
||||
and check_method(getattr(module, method)):
|
||||
and ObjectUtils.check_method(getattr(module, method)):
|
||||
yield module
|
||||
|
@ -1,9 +1,10 @@
|
||||
import traceback
|
||||
from typing import List, Any
|
||||
from typing import List, Any, Dict
|
||||
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.module import ModuleHelper
|
||||
from app.log import logger
|
||||
from app.utils.object import ObjectUtils
|
||||
from app.utils.singleton import Singleton
|
||||
|
||||
|
||||
@ -24,6 +25,7 @@ class PluginManager(metaclass=Singleton):
|
||||
self.init_config()
|
||||
|
||||
def init_config(self):
|
||||
# 配置管理
|
||||
self.systemconfig = SystemConfigOper()
|
||||
# 停止已有插件
|
||||
self.stop()
|
||||
@ -32,21 +34,7 @@ class PluginManager(metaclass=Singleton):
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
启动
|
||||
"""
|
||||
# 加载插件
|
||||
self.__load_plugins()
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
停止
|
||||
"""
|
||||
# 停止所有插件
|
||||
self.__stop_plugins()
|
||||
|
||||
def __load_plugins(self):
|
||||
"""
|
||||
加载所有插件
|
||||
启动加载插件
|
||||
"""
|
||||
# 扫描插件目录
|
||||
plugins = ModuleHelper.load(
|
||||
@ -59,32 +47,24 @@ class PluginManager(metaclass=Singleton):
|
||||
self._plugins = {}
|
||||
for plugin in plugins:
|
||||
plugin_id = plugin.__name__
|
||||
self._plugins[plugin_id] = plugin
|
||||
# 生成实例
|
||||
self._running_plugins[plugin_id] = plugin()
|
||||
# 初始化配置
|
||||
self.reload_plugin(plugin_id)
|
||||
logger.info(f"Plugin Loaded:{plugin.__name__}")
|
||||
|
||||
def reload_plugin(self, pid: str):
|
||||
"""
|
||||
生效插件配置
|
||||
"""
|
||||
if not pid:
|
||||
return
|
||||
if not self._running_plugins.get(pid):
|
||||
return
|
||||
if hasattr(self._running_plugins[pid], "init_plugin"):
|
||||
try:
|
||||
self._running_plugins[pid].init_plugin(self.get_plugin_config(pid))
|
||||
logger.debug(f"生效插件配置:{pid}")
|
||||
# 存储Class
|
||||
self._plugins[plugin_id] = plugin
|
||||
# 生成实例
|
||||
plugin_obj = plugin()
|
||||
# 生效插件配置
|
||||
plugin_obj.init_plugin(self.get_plugin_config(plugin_id))
|
||||
# 存储运行实例
|
||||
self._running_plugins[plugin_id] = plugin_obj
|
||||
logger.info(f"Plugin Loaded:{plugin_id}")
|
||||
except Exception as err:
|
||||
logger.error(f"加载插件 {pid} 出错:{err} - {traceback.format_exc()}")
|
||||
logger.error(f"加载插件 {plugin_id} 出错:{err} - {traceback.format_exc()}")
|
||||
|
||||
def __stop_plugins(self):
|
||||
def stop(self):
|
||||
"""
|
||||
停止所有插件
|
||||
停止
|
||||
"""
|
||||
# 停止所有插件
|
||||
for plugin in self._running_plugins.values():
|
||||
if hasattr(plugin, "stop"):
|
||||
plugin.stop()
|
||||
@ -105,7 +85,7 @@ class PluginManager(metaclass=Singleton):
|
||||
return False
|
||||
return self.systemconfig.set(self._config_key % pid, conf)
|
||||
|
||||
def get_plugin_commands(self) -> List[dict]:
|
||||
def get_plugin_commands(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取插件命令
|
||||
[{
|
||||
@ -117,8 +97,9 @@ class PluginManager(metaclass=Singleton):
|
||||
"""
|
||||
ret_commands = []
|
||||
for _, plugin in self._running_plugins.items():
|
||||
if hasattr(plugin, "get_command"):
|
||||
ret_commands.append(plugin.get_command())
|
||||
if hasattr(plugin, "get_command") \
|
||||
and ObjectUtils.check_method(plugin.get_command):
|
||||
ret_commands += plugin.get_command()
|
||||
return ret_commands
|
||||
|
||||
def run_plugin_method(self, pid: str, method: str, *args, **kwargs) -> Any:
|
||||
|
@ -1,6 +1,6 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
@ -27,8 +27,6 @@ class _PluginBase(metaclass=ABCMeta):
|
||||
- update_config() 更新配置信息
|
||||
- init_plugin() 生效配置信息
|
||||
- get_data_path() 获取插件数据保存目录
|
||||
- get_command() 获取插件命令,使用消息机制通过远程控制
|
||||
|
||||
"""
|
||||
# 插件名称
|
||||
plugin_name: str = ""
|
||||
@ -48,6 +46,20 @@ class _PluginBase(metaclass=ABCMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_command() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取插件命令
|
||||
[{
|
||||
"cmd": "/xx",
|
||||
"event": EventType.xx,
|
||||
"desc": "xxxx",
|
||||
"data": {}
|
||||
}]
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop_service(self):
|
||||
"""
|
||||
|
@ -2,12 +2,13 @@ import traceback
|
||||
from multiprocessing.dummy import Pool as ThreadPool
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from threading import Event
|
||||
from typing import Any
|
||||
from typing import Any, List, Dict
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from ruamel.yaml import CommentedMap
|
||||
|
||||
from app import schemas
|
||||
from app.core.event import EventManager, eventmanager
|
||||
from app.core.config import settings
|
||||
from app.helper.browser import PlaywrightHelper
|
||||
@ -18,6 +19,7 @@ from app.log import logger
|
||||
from app.plugins import _PluginBase
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.site import SiteUtils
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.timer import TimerUtils
|
||||
from app.schemas.types import EventType
|
||||
|
||||
@ -64,17 +66,17 @@ class AutoSignIn(_PluginBase):
|
||||
self._scheduler.start()
|
||||
|
||||
@staticmethod
|
||||
def get_command() -> dict:
|
||||
def get_command() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
定义远程控制命令
|
||||
:return: 命令关键字、事件、描述、附带数据
|
||||
"""
|
||||
return {
|
||||
return [{
|
||||
"cmd": "/site_signin",
|
||||
"event": EventType.SiteSignin,
|
||||
"desc": "站点签到",
|
||||
"data": {}
|
||||
}
|
||||
}]
|
||||
|
||||
@eventmanager.register(EventType.SiteSignin)
|
||||
def sign_in(self, event: Event = None):
|
||||
@ -110,6 +112,23 @@ class AutoSignIn(_PluginBase):
|
||||
logger.error("站点模块加载失败:%s" % str(e))
|
||||
return None
|
||||
|
||||
def signin_by_domain(self, url: str) -> schemas.Response:
|
||||
"""
|
||||
签到一个站点,可由API调用
|
||||
"""
|
||||
domain = StringUtils.get_url_domain(url)
|
||||
site_info = self.sites.get_indexer(domain)
|
||||
if site_info:
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message=f"站点【{url}】不存在"
|
||||
)
|
||||
else:
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message=self.signin_site(site_info)
|
||||
)
|
||||
|
||||
def signin_site(self, site_info: CommentedMap) -> str:
|
||||
"""
|
||||
签到一个站点
|
||||
|
@ -1,12 +1,13 @@
|
||||
from datetime import datetime
|
||||
from multiprocessing.dummy import Pool as ThreadPool
|
||||
from threading import Lock
|
||||
from typing import Optional, Any, List
|
||||
from typing import Optional, Any, List, Dict
|
||||
|
||||
import requests
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from ruamel.yaml import CommentedMap
|
||||
|
||||
from app import schemas
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.event import Event
|
||||
@ -64,17 +65,17 @@ class SiteStatistic(_PluginBase):
|
||||
self._scheduler.start()
|
||||
|
||||
@staticmethod
|
||||
def get_command() -> dict:
|
||||
def get_command() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
定义远程控制命令
|
||||
:return: 命令关键字、事件、描述、附带数据
|
||||
"""
|
||||
return {
|
||||
return [{
|
||||
"cmd": "/site_statistic",
|
||||
"event": EventType.SiteStatistic,
|
||||
"desc": "站点数据统计",
|
||||
"data": {}
|
||||
}
|
||||
}]
|
||||
|
||||
def stop_service(self):
|
||||
pass
|
||||
@ -181,6 +182,28 @@ class SiteStatistic(_PluginBase):
|
||||
return site_schema(site_name, url, site_cookie, html_text, session=session, ua=ua, proxy=proxy)
|
||||
return None
|
||||
|
||||
def refresh_by_domain(self, domain: str) -> schemas.Response:
|
||||
"""
|
||||
刷新一个站点数据,可由API调用
|
||||
"""
|
||||
site_info = self.sites.get_indexer(domain)
|
||||
if site_info:
|
||||
site_data = self.__refresh_site_data(site_info)
|
||||
if site_data:
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message=f"站点 {domain} 刷新成功",
|
||||
data=site_data.to_dict()
|
||||
)
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"站点 {domain} 刷新数据失败,未获取到数据"
|
||||
)
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"站点 {domain} 不存在"
|
||||
)
|
||||
|
||||
def __refresh_site_data(self, site_info: CommentedMap) -> Optional[ISiteUserInfo]:
|
||||
"""
|
||||
更新单个site 数据信息
|
||||
|
@ -6,3 +6,4 @@ from pydantic import BaseModel
|
||||
class Response(BaseModel):
|
||||
success: bool
|
||||
message: Optional[str] = None
|
||||
data: Optional[dict] = {}
|
||||
|
@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
from types import FunctionType
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
@ -25,3 +26,10 @@ class ObjectUtils:
|
||||
parameter_names = parameter_names[1:]
|
||||
|
||||
return len(parameter_names)
|
||||
|
||||
@staticmethod
|
||||
def check_method(func: FunctionType) -> bool:
|
||||
"""
|
||||
检查函数是否已实现
|
||||
"""
|
||||
return func.__code__.co_code != b'd\x01S\x00'
|
||||
|
Loading…
x
Reference in New Issue
Block a user