diff --git a/app/core/plugin.py b/app/core/plugin.py index ac352e3e..31e41f2b 100644 --- a/app/core/plugin.py +++ b/app/core/plugin.py @@ -2,11 +2,12 @@ import concurrent import concurrent.futures import importlib.util import inspect +import os import threading import time import traceback from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from watchdog.events import FileSystemEventHandler from watchdog.observers import Observer @@ -21,6 +22,7 @@ from app.helper.plugin import PluginHelper from app.helper.sites import SitesHelper from app.log import logger from app.schemas.types import SystemConfigKey +from app.utils.crypto import RSAUtils from app.utils.object import ObjectUtils from app.utils.singleton import Singleton from app.utils.string import StringUtils @@ -159,11 +161,12 @@ class PluginManager(metaclass=Singleton): if pid and plugin_id != pid: continue try: - # 如果插件具有认证级别且当前认证级别不足,则不进行实例化 - if hasattr(plugin, "auth_level"): - plugin.auth_level = plugin.auth_level - if self.siteshelper.auth_level < plugin.auth_level: - continue + # 判断插件是否满足认证要求,如不满足则不进行实例化 + if not self.__set_and_check_auth_level(plugin=plugin): + # 如果是插件热更新实例,这里则进行替换 + if plugin_id in self._plugins: + self._plugins[plugin_id] = plugin + continue # 存储Class self._plugins[plugin_id] = plugin # 未安装的不加载 @@ -601,11 +604,12 @@ class PluginManager(metaclass=Singleton): if plugin_obj and hasattr(plugin_obj, "get_page"): if ObjectUtils.check_method(plugin_obj.get_page): plugin.has_page = True + # 公钥 + if plugin_info.get("key"): + plugin.plugin_public_key = plugin_info.get("key") # 权限 - if plugin_info.get("level"): - plugin.auth_level = plugin_info.get("level") - if self.siteshelper.auth_level < plugin.auth_level: - continue + if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info): + continue # 名称 if plugin_info.get("name"): plugin.plugin_name = plugin_info.get("name") @@ -708,11 +712,12 @@ class PluginManager(metaclass=Singleton): plugin.has_page = True else: plugin.has_page = False + # 公钥 + if hasattr(plugin_class, "plugin_public_key"): + plugin.plugin_public_key = plugin_class.plugin_public_key # 权限 - if hasattr(plugin_class, "auth_level"): - plugin.auth_level = plugin_class.auth_level - if self.siteshelper.auth_level < plugin.auth_level: - continue + if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_class): + continue # 名称 if hasattr(plugin_class, "plugin_name"): plugin.plugin_name = plugin_class.plugin_name @@ -762,3 +767,57 @@ class PluginManager(metaclass=Singleton): except Exception as e: logger.debug(f"获取插件是否在本地包中存在失败,{e}") return False + + def __set_and_check_auth_level(self, plugin: Union[schemas.Plugin, Type[Any]], + source: Optional[Union[dict, Type[Any]]] = None) -> bool: + """ + 设置并检查插件的认证级别 + :param plugin: 插件对象或包含 auth_level 属性的对象 + :param source: 可选的字典对象或类对象,可能包含 "level" 或 "auth_level" 键 + :return: 如果插件的认证级别有效且当前环境的认证级别满足要求,返回 True,否则返回 False + """ + # 检查并赋值 source 中的 level 或 auth_level + if source: + if isinstance(source, dict) and "level" in source: + plugin.auth_level = source.get("level") + elif hasattr(source, "auth_level"): + plugin.auth_level = source.auth_level + # 如果 source 为空且 plugin 本身没有 auth_level,直接返回 True + elif not hasattr(plugin, "auth_level"): + return True + + # auth_level 级别说明 + # 1 - 所有用户可见 + # 2 - 站点认证用户可见 + # 3 - 站点&密钥认证可见 + # 99 - 站点&特殊密钥认证可见 + # 如果当前站点认证级别大于 1 且插件级别为 99,并存在插件公钥,说明为特殊密钥认证,通过密钥匹配进行认证 + if self.siteshelper.auth_level > 1 and plugin.auth_level == 99 and hasattr(plugin, "plugin_public_key"): + plugin_id = plugin.id if isinstance(plugin, schemas.Plugin) else plugin.__name__ + public_key = plugin.plugin_public_key + if public_key: + private_key = PluginManager.__get_plugin_private_key(plugin_id) + verify = RSAUtils.verify_rsa_keys(public_key=public_key, private_key=private_key) + return verify + # 如果当前站点认证级别小于插件级别,则返回 False + if self.siteshelper.auth_level < plugin.auth_level: + return False + return True + + @staticmethod + def __get_plugin_private_key(plugin_id: str) -> Optional[str]: + """ + 根据插件标识获取对应的私钥 + :param plugin_id: 插件标识 + :return: 对应的插件私钥,如果未找到则返回 None + """ + try: + # 将插件标识转换为大写并构建环境变量名称 + env_var_name = f"PLUGIN_{plugin_id.upper()}_PRIVATE_KEY" + private_key = os.environ.get(env_var_name) + if private_key is None: + logger.debug(f"环境变量 {env_var_name} 未找到。") + return private_key + except Exception as e: + logger.debug(f"获取插件 {plugin_id} 的私钥时发生错误:{e}") + return None diff --git a/app/schemas/plugin.py b/app/schemas/plugin.py index c38fd074..742ff1a7 100644 --- a/app/schemas/plugin.py +++ b/app/schemas/plugin.py @@ -46,6 +46,8 @@ class Plugin(BaseModel): history: Optional[dict] = {} # 添加时间,值越小表示越靠后发布 add_time: Optional[int] = 0 + # 插件公钥 + plugin_public_key: Optional[str] = None class PluginDashboard(Plugin): diff --git a/app/utils/crypto.py b/app/utils/crypto.py new file mode 100644 index 00000000..b1b7dc91 --- /dev/null +++ b/app/utils/crypto.py @@ -0,0 +1,91 @@ +import base64 + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import rsa, padding + + +class RSAUtils: + + @staticmethod + def generate_rsa_key_pair() -> (str, str): + """ + 生成RSA密钥对并返回Base64编码的公钥和私钥(DER格式) + + :return: Tuple containing Base64 encoded public key and private key + """ + # 生成RSA密钥对 + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + public_key = private_key.public_key() + + # 导出私钥为DER格式 + private_key_der = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ) + + # 导出公钥为DER格式 + public_key_der = public_key.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + + # 将DER格式的密钥编码为Base64 + private_key_b64 = base64.b64encode(private_key_der).decode('utf-8') + public_key_b64 = base64.b64encode(public_key_der).decode('utf-8') + + return private_key_b64, public_key_b64 + + @staticmethod + def verify_rsa_keys(private_key: str, public_key: str) -> bool: + """ + 使用 RSA 验证公钥和私钥是否匹配 + + :param private_key: 私钥字符串 (Base64 编码,无标识符) + :param public_key: 公钥字符串 (Base64 编码,无标识符) + :return: 如果匹配则返回 True,否则返回 False + """ + if not private_key or not public_key: + return False + + try: + # 解码 Base64 编码的公钥和私钥 + public_key_bytes = base64.b64decode(public_key) + private_key_bytes = base64.b64decode(private_key) + + # 加载公钥 + public_key = serialization.load_der_public_key(public_key_bytes, backend=default_backend()) + + # 加载私钥 + private_key = serialization.load_der_private_key(private_key_bytes, password=None, + backend=default_backend()) + + # 测试加解密 + message = b'test' + encrypted_message = public_key.encrypt( + message, + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None + ) + ) + + decrypted_message = private_key.decrypt( + encrypted_message, + padding.OAEP( + mgf=padding.MGF1(algorithm=hashes.SHA256()), + algorithm=hashes.SHA256(), + label=None + ) + ) + + return message == decrypted_message + except Exception as e: + print(f"RSA 密钥验证失败: {e}") + return False