diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py index 1dc6b0a2..4e7031d6 100644 --- a/app/api/endpoints/login.py +++ b/app/api/endpoints/login.py @@ -30,35 +30,33 @@ async def login_access_token( 获取认证Token """ # 检查数据库 - user = User.authenticate( + success, user = User.authenticate( db=db, name=form_data.username, password=form_data.password, otp_password=otp_password ) - if not user: - # 请求协助认证 - logger.warn(f"登录用户 {form_data.username} 本地用户名或密码不匹配,尝试辅助认证 ...") - token = UserChain().user_authenticate(form_data.username, form_data.password) - if not token: - logger.warn(f"用户 {form_data.username} 登录失败!") - raise HTTPException(status_code=401, detail="用户名、密码、二次校验不正确") - else: - logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...") - # 加入用户信息表 - user = User.get_by_name(db=db, name=form_data.username) - if not user: - logger.info(f"用户不存在,创建用户: {form_data.username}") + if not success: + # 认证不成功 + if not user: + # 未找到用户,请求协助认证 + logger.warn(f"登录用户 {form_data.username} 本地不存在,尝试辅助认证 ...") + token = UserChain().user_authenticate(form_data.username, form_data.password) + if not token: + logger.warn(f"用户 {form_data.username} 登录失败!") + raise HTTPException(status_code=401, detail="用户名、密码、二次校验码不正确") + else: + logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...") + # 加入用户信息表 + logger.info(f"创建用户: {form_data.username}") user = User(name=form_data.username, is_active=True, is_superuser=False, hashed_password=get_password_hash(token)) user.create(db) - else: - # 辅助验证用户若未启用,则禁止登录 - if not user.is_active: - raise HTTPException(status_code=403, detail="用户未启用") - # 普通用户权限 - user.is_superuser = False - elif not user.is_active: + else: + # 用户存在,但认证失败 + logger.warn(f"用户 {user.name} 登录失败!") + raise HTTPException(status_code=401, detail="用户名、密码或二次校验码不正确") + elif user and not user.is_active: raise HTTPException(status_code=403, detail="用户未启用") logger.info(f"用户 {user.name} 登录成功!") return schemas.Token( diff --git a/app/api/endpoints/plugin.py b/app/api/endpoints/plugin.py index 6d374dae..5573b403 100644 --- a/app/api/endpoints/plugin.py +++ b/app/api/endpoints/plugin.py @@ -14,16 +14,16 @@ router = APIRouter() @router.get("/", summary="所有插件", response_model=List[schemas.Plugin]) -def all_plugins(_: schemas.TokenPayload = Depends(verify_token), state: str = "all") -> Any: +def all_plugins(_: schemas.TokenPayload = Depends(verify_token), state: str = "all") -> List[schemas.Plugin]: """ 查询所有插件清单,包括本地插件和在线插件,插件状态:installed, market, all """ # 本地插件 local_plugins = PluginManager().get_local_plugins() # 已安装插件 - installed_plugins = [plugin for plugin in local_plugins if plugin.get("installed")] + installed_plugins = [plugin for plugin in local_plugins if plugin.installed] # 未安装的本地插件 - not_installed_plugins = [plugin for plugin in local_plugins if not plugin.get("installed")] + not_installed_plugins = [plugin for plugin in local_plugins if not plugin.installed] if state == "installed": return installed_plugins @@ -39,17 +39,17 @@ def all_plugins(_: schemas.TokenPayload = Depends(verify_token), state: str = "a # 插件市场插件清单 market_plugins = [] # 已安装插件IDS - _installed_ids = [plugin["id"] for plugin in installed_plugins] + _installed_ids = [plugin.id for plugin in installed_plugins] # 未安装的线上插件或者有更新的插件 for plugin in online_plugins: - if plugin["id"] not in _installed_ids: + if plugin.id not in _installed_ids: market_plugins.append(plugin) - elif plugin.get("has_update"): + elif plugin.has_update: market_plugins.append(plugin) # 未安装的本地插件,且不在线上插件中 - _plugin_ids = [plugin["id"] for plugin in market_plugins] + _plugin_ids = [plugin.id for plugin in market_plugins] for plugin in not_installed_plugins: - if plugin["id"] not in _plugin_ids: + if plugin.id not in _plugin_ids: market_plugins.append(plugin) # 返回插件清单 if state == "market": diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 5fae8602..85ae3036 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -17,8 +17,8 @@ router = APIRouter() @router.get("/", summary="所有用户", response_model=List[schemas.User]) def read_users( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_active_superuser), ) -> Any: """ 查询用户列表 @@ -29,10 +29,10 @@ def read_users( @router.post("/", summary="新增用户", response_model=schemas.Response) def create_user( - *, - db: Session = Depends(get_db), - user_in: schemas.UserCreate, - current_user: User = Depends(get_current_active_superuser), + *, + db: Session = Depends(get_db), + user_in: schemas.UserCreate, + current_user: User = Depends(get_current_active_superuser), ) -> Any: """ 新增用户 @@ -51,10 +51,10 @@ def create_user( @router.put("/", summary="更新用户", response_model=schemas.Response) def update_user( - *, - db: Session = Depends(get_db), - user_in: schemas.UserCreate, - _: User = Depends(get_current_active_superuser), + *, + db: Session = Depends(get_db), + user_in: schemas.UserCreate, + _: User = Depends(get_current_active_superuser), ) -> Any: """ 更新用户 @@ -64,7 +64,8 @@ def update_user( # 正则表达式匹配密码包含字母、数字、特殊字符中的至少两项 pattern = r'^(?![a-zA-Z]+$)(?!\d+$)(?![^\da-zA-Z\s]+$).{6,50}$' if not re.match(pattern, user_info.get("password")): - return schemas.Response(success=False, message="密码需要同时包含字母、数字、特殊字符中的至少两项,且长度大于6位") + return schemas.Response(success=False, + message="密码需要同时包含字母、数字、特殊字符中的至少两项,且长度大于6位") user_info["hashed_password"] = get_password_hash(user_info["password"]) user_info.pop("password") user = User.get_by_name(db, name=user_info["name"]) @@ -76,7 +77,7 @@ def update_user( @router.get("/current", summary="当前登录用户信息", response_model=schemas.User) def read_current_user( - current_user: User = Depends(get_current_active_user) + current_user: User = Depends(get_current_active_user) ) -> Any: """ 当前登录用户信息 @@ -102,12 +103,51 @@ async def upload_avatar(user_id: int, db: Session = Depends(get_db), return schemas.Response(success=True, message=file.filename) +@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response) +def otp_generate( + current_user: User = Depends(get_current_active_user) +) -> Any: + secret, uri = OtpUtils.generate_secret_key(current_user.name) + return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri}) + + +@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response) +def otp_judge( + data: dict, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_active_user) +) -> Any: + uri = data.get("uri") + otp_password = data.get("otpPassword") + if not OtpUtils.is_legal(uri, otp_password): + return schemas.Response(success=False, message="验证码错误") + current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) + return schemas.Response(success=True) + + +@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response) +def otp_disable( + db: Session = Depends(get_db), + current_user: User = Depends(get_current_active_user) +) -> Any: + current_user.update_otp_by_name(db, current_user.name, False, "") + return schemas.Response(success=True) + + +@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response) +def otp_enable(userid: str, db: Session = Depends(get_db)) -> Any: + user: User = User.get_by_name(db, userid) + if not user: + return schemas.Response(success=False, message="用户不存在") + return schemas.Response(success=user.is_otp) + + @router.delete("/{user_name}", summary="删除用户", response_model=schemas.Response) def delete_user( - *, - db: Session = Depends(get_db), - user_name: str, - current_user: User = Depends(get_current_active_superuser), + *, + db: Session = Depends(get_db), + user_name: str, + current_user: User = Depends(get_current_active_superuser), ) -> Any: """ 删除用户 @@ -121,9 +161,9 @@ def delete_user( @router.get("/{user_id}", summary="用户详情", response_model=schemas.User) def read_user_by_id( - user_id: int, - current_user: User = Depends(get_current_active_user), - db: Session = Depends(get_db), + user_id: int, + current_user: User = Depends(get_current_active_user), + db: Session = Depends(get_db), ) -> Any: """ 查询用户详情 @@ -142,34 +182,3 @@ def read_user_by_id( detail="用户权限不足" ) return user - - -@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response) -def otp_generate( - current_user: User = Depends(get_current_active_user) -) -> Any: - secret, uri = OtpUtils.generate_secret_key(current_user.name) - return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri}) - - -@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response) -def otp_judge( - data: dict, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user) -) -> Any: - uri = data.get("uri") - otp_password = data.get("otpPassword") - if not OtpUtils.is_legal(uri, otp_password): - return schemas.Response(success=False, message="验证码错误") - current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) - return schemas.Response(success=True) - - -@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response) -def otp_disable( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user) -) -> Any: - current_user.update_otp_by_name(db, current_user.name, False, "") - return schemas.Response(success=True) diff --git a/app/core/plugin.py b/app/core/plugin.py index 17db7613..607d525b 100644 --- a/app/core/plugin.py +++ b/app/core/plugin.py @@ -3,6 +3,7 @@ import concurrent.futures import traceback from typing import List, Any, Dict, Tuple, Optional +from app import schemas from app.core.config import settings from app.core.event import eventmanager from app.db.systemconfig_oper import SystemConfigOper @@ -130,16 +131,16 @@ class PluginManager(metaclass=Singleton): # 支持更新的插件自动更新 for plugin in online_plugins: # 只处理已安装的插件 - if plugin.get("id") in install_plugins and not self.is_plugin_exists(plugin.get("id")): + if plugin.id in install_plugins and not self.is_plugin_exists(plugin.id): # 下载安装 - state, msg = self.pluginhelper.install(pid=plugin.get("id"), - repo_url=plugin.get("repo_url")) + state, msg = self.pluginhelper.install(pid=plugin.id, + repo_url=plugin.repo_url) # 安装失败 if not state: logger.error( - f"插件 {plugin.get('plugin_name')} v{plugin.get('plugin_version')} 安装失败:{msg}") + f"插件 {plugin.plugin_name} v{plugin.plugin_version} 安装失败:{msg}") continue - logger.info(f"插件 {plugin.get('plugin_name')} 安装成功,版本:{plugin.get('plugin_version')}") + logger.info(f"插件 {plugin.plugin_name} 安装成功,版本:{plugin.plugin_version}") logger.info("第三方插件安装完成") def get_plugin_config(self, pid: str) -> dict: @@ -204,7 +205,10 @@ class PluginManager(metaclass=Singleton): for _, plugin in self._running_plugins.items(): if hasattr(plugin, "get_command") \ and ObjectUtils.check_method(plugin.get_command): - ret_commands += plugin.get_command() or [] + try: + ret_commands += plugin.get_command() or [] + except Exception as e: + logger.error(f"获取插件命令出错:{str(e)}") return ret_commands def get_plugin_apis(self) -> List[Dict[str, Any]]: @@ -222,10 +226,13 @@ class PluginManager(metaclass=Singleton): for pid, plugin in self._running_plugins.items(): if hasattr(plugin, "get_api") \ and ObjectUtils.check_method(plugin.get_api): - apis = plugin.get_api() or [] - for api in apis: - api["path"] = f"/{pid}{api['path']}" - ret_apis.extend(apis) + try: + apis = plugin.get_api() or [] + for api in apis: + api["path"] = f"/{pid}{api['path']}" + ret_apis.extend(apis) + except Exception as e: + logger.error(f"获取插件 {pid} API出错:{str(e)}") return ret_apis def get_plugin_services(self) -> List[Dict[str, Any]]: @@ -243,9 +250,12 @@ class PluginManager(metaclass=Singleton): for pid, plugin in self._running_plugins.items(): if hasattr(plugin, "get_service") \ and ObjectUtils.check_method(plugin.get_service): - services = plugin.get_service() - if services: - ret_services.extend(services) + try: + services = plugin.get_service() + if services: + ret_services.extend(services) + except Exception as e: + logger.error(f"获取插件 {pid} 服务出错:{str(e)}") return ret_services def get_plugin_attr(self, pid: str, attr: str) -> Any: @@ -280,11 +290,11 @@ class PluginManager(metaclass=Singleton): """ return list(self._running_plugins.keys()) - def get_online_plugins(self) -> List[dict]: + def get_online_plugins(self) -> List[schemas.Plugin]: """ 获取所有在线插件信息 """ - def __get_plugin_info(market: str) -> Optional[List[dict]]: + def __get_plugin_info(market: str) -> Optional[List[schemas.Plugin]]: """ 获取插件信息 """ @@ -293,27 +303,27 @@ class PluginManager(metaclass=Singleton): logger.warn(f"获取插件库失败:{market}") return ret_plugins = [] - for pid, plugin in online_plugins.items(): + for pid, plugin_info in online_plugins.items(): # 运行状插件 plugin_obj = self._running_plugins.get(pid) # 非运行态插件 plugin_static = self._plugins.get(pid) # 基本属性 - conf = {} + plugin = schemas.Plugin() # ID - conf.update({"id": pid}) + plugin.id = pid # 安装状态 if pid in installed_apps and plugin_static: - conf.update({"installed": True}) + plugin.installed = True else: - conf.update({"installed": False}) + plugin.installed = False # 是否有新版本 - conf.update({"has_update": False}) + plugin.has_update = False if plugin_static: installed_version = getattr(plugin_static, "plugin_version") - if StringUtils.compare_version(installed_version, plugin.get("version")) < 0: + if StringUtils.compare_version(installed_version, plugin_info.get("version")) < 0: # 需要更新 - conf.update({"has_update": True}) + plugin.has_update = True # 运行状态 if plugin_obj and hasattr(plugin_obj, "get_state"): try: @@ -321,40 +331,40 @@ class PluginManager(metaclass=Singleton): except Exception as e: logger.error(f"获取插件 {pid} 状态出错:{str(e)}") state = False - conf.update({"state": state}) + plugin.state = state else: - conf.update({"state": False}) + plugin.state = False # 是否有详情页面 - conf.update({"has_page": False}) + plugin.has_page = False if plugin_obj and hasattr(plugin_obj, "get_page"): if ObjectUtils.check_method(plugin_obj.get_page): - conf.update({"has_page": True}) + plugin.has_page = True # 权限 - if plugin.get("level"): - conf.update({"auth_level": plugin.get("level")}) - if self.siteshelper.auth_level < plugin.get("level"): + if plugin_info.get("level"): + plugin.auth_level = plugin_info.get("level") + if self.siteshelper.auth_level < plugin.auth_level: continue # 名称 - if plugin.get("name"): - conf.update({"plugin_name": plugin.get("name")}) + if plugin_info.get("name"): + plugin.plugin_name = plugin_info.get("name") # 描述 - if plugin.get("description"): - conf.update({"plugin_desc": plugin.get("description")}) + if plugin_info.get("description"): + plugin.plugin_desc = plugin_info.get("description") # 版本 - if plugin.get("version"): - conf.update({"plugin_version": plugin.get("version")}) + if plugin_info.get("version"): + plugin.plugin_version = plugin_info.get("version") # 图标 - if plugin.get("icon"): - conf.update({"plugin_icon": plugin.get("icon")}) + if plugin_info.get("icon"): + plugin.plugin_icon = plugin_info.get("icon") # 作者 - if plugin.get("author"): - conf.update({"plugin_author": plugin.get("author")}) + if plugin_info.get("author"): + plugin.plugin_author = plugin_info.get("author") # 仓库链接 - conf.update({"repo_url": market}) + plugin.repo_url = market # 本地标志 - conf.update({"is_local": False}) + plugin.is_local = False # 汇总 - ret_plugins.append(conf) + ret_plugins.append(plugin) return ret_plugins @@ -375,39 +385,39 @@ class PluginManager(metaclass=Singleton): all_plugins.extend(plugins) # 所有插件按repo在设置中的顺序排序 all_plugins.sort( - key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.get("repo_url")) if x.get("repo_url") else 0 + key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0 ) # 按插件ID和版本号去重,相同插件以前面的为准 result = [] _dup = [] for p in all_plugins: - key = f"{p.get('id')}v{p.get('plugin_version')}" + key = f"{p.id}v{p.plugin_version}" if key not in _dup: _dup.append(key) result.append(p) logger.info(f"共获取到 {len(result)} 个第三方插件") return result - def get_local_plugins(self) -> List[dict]: + def get_local_plugins(self) -> List[schemas.Plugin]: """ 获取所有本地已下载的插件信息 """ # 返回值 - all_confs = [] + plugins = [] # 已安装插件 installed_apps = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or [] - for pid, plugin in self._plugins.items(): + for pid, plugin_class in self._plugins.items(): # 运行状插件 plugin_obj = self._running_plugins.get(pid) # 基本属性 - conf = {} + plugin = schemas.Plugin() # ID - conf.update({"id": pid}) + plugin.id = pid # 安装状态 if pid in installed_apps: - conf.update({"installed": True}) + plugin.installed = True else: - conf.update({"installed": False}) + plugin.installed = False # 运行状态 if plugin_obj and hasattr(plugin_obj, "get_state"): try: @@ -415,45 +425,45 @@ class PluginManager(metaclass=Singleton): except Exception as e: logger.error(f"获取插件 {pid} 状态出错:{str(e)}") state = False - conf.update({"state": state}) + plugin.state = state else: - conf.update({"state": False}) + plugin.state = False # 是否有详情页面 - if hasattr(plugin, "get_page"): - if ObjectUtils.check_method(plugin.get_page): - conf.update({"has_page": True}) + if hasattr(plugin_class, "get_page"): + if ObjectUtils.check_method(plugin_class.get_page): + plugin.has_page = True else: - conf.update({"has_page": False}) + plugin.has_page = False # 权限 - if hasattr(plugin, "auth_level"): - conf.update({"auth_level": plugin.auth_level}) + if hasattr(plugin_class, "auth_level"): + plugin.auth_level = plugin_class.auth_level if self.siteshelper.auth_level < plugin.auth_level: continue # 名称 - if hasattr(plugin, "plugin_name"): - conf.update({"plugin_name": plugin.plugin_name}) + if hasattr(plugin_class, "plugin_name"): + plugin.plugin_name = plugin_class.plugin_name # 描述 - if hasattr(plugin, "plugin_desc"): - conf.update({"plugin_desc": plugin.plugin_desc}) + if hasattr(plugin_class, "plugin_desc"): + plugin.plugin_desc = plugin_class.plugin_desc # 版本 - if hasattr(plugin, "plugin_version"): - conf.update({"plugin_version": plugin.plugin_version}) + if hasattr(plugin_class, "plugin_version"): + plugin.plugin_version = plugin_class.plugin_version # 图标 - if hasattr(plugin, "plugin_icon"): - conf.update({"plugin_icon": plugin.plugin_icon}) + if hasattr(plugin_class, "plugin_icon"): + plugin.plugin_icon = plugin_class.plugin_icon # 作者 - if hasattr(plugin, "plugin_author"): - conf.update({"plugin_author": plugin.plugin_author}) + if hasattr(plugin_class, "plugin_author"): + plugin.plugin_author = plugin_class.plugin_author # 作者链接 - if hasattr(plugin, "author_url"): - conf.update({"author_url": plugin.author_url}) + if hasattr(plugin_class, "author_url"): + plugin.author_url = plugin_class.author_url # 是否需要更新 - conf.update({"has_update": False}) + plugin.has_update = False # 本地标志 - conf.update({"is_local": True}) + plugin.is_local = True # 汇总 - all_confs.append(conf) - return all_confs + plugins.append(plugin) + return plugins @staticmethod def is_plugin_exists(pid: str) -> bool: diff --git a/app/db/models/user.py b/app/db/models/user.py index 527befe0..742d534c 100644 --- a/app/db/models/user.py +++ b/app/db/models/user.py @@ -1,10 +1,14 @@ +from typing import Tuple, Optional + from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy.orm import Session from app.core.security import verify_password from app.db import db_query, db_update, Base +from app.schemas import User from app.utils.otp import OtpUtils + class User(Base): """ 用户表 @@ -30,16 +34,16 @@ class User(Base): @staticmethod @db_query - def authenticate(db: Session, name: str, password: str, otp_password: str): + def authenticate(db: Session, name: str, password: str, otp_password: str) -> Tuple[bool, Optional[User]]: user = db.query(User).filter(User.name == name).first() if not user: - return None + return False, None if not verify_password(password, str(user.hashed_password)): - return None + return False, user if user.is_otp: if not otp_password or not OtpUtils.check(user.otp_secret, otp_password): - return None - return user + return False, user + return True, user @staticmethod @db_query diff --git a/app/helper/torrent.py b/app/helper/torrent.py index c70244b3..deefd912 100644 --- a/app/helper/torrent.py +++ b/app/helper/torrent.py @@ -326,7 +326,10 @@ class TorrentHelper(metaclass=Singleton): return True # 匹配内容 - content = f"{torrent_info.title} {torrent_info.description} {' '.join(torrent_info.labels or [])}" + content = (f"{torrent_info.title} " + f"{torrent_info.description} " + f"{' '.join(torrent_info.labels or [])} " + f"{torrent_info.volume_factor}") # 最少做种人数 min_seeders = filter_rule.get("min_seeders")