diff --git a/app/api/endpoints/media.py b/app/api/endpoints/media.py index f9593af4..20a48d8d 100644 --- a/app/api/endpoints/media.py +++ b/app/api/endpoints/media.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Any +from typing import List, Any, Union from fastapi import APIRouter, Depends @@ -72,15 +72,29 @@ def search(title: str, """ 模糊搜索媒体/人物信息列表 media:媒体信息,person:人物信息 """ + def __get_source(obj: Union[dict, schemas.MediaPerson]): + """ + 获取对象属性 + """ + if isinstance(obj, dict): + return obj.get("source") + return obj.source + + result = [] if type == "media": _, medias = MediaChain().search(title=title) if medias: - return [media.to_dict() for media in medias[(page - 1) * count: page * count]] + result = [media.to_dict() for media in medias] else: - persons = MediaChain().search_persons(name=title) - if persons: - return persons[(page - 1) * count: page * count] - return [] + result = MediaChain().search_persons(name=title) + if result: + # 按设置的顺序对结果进行排序 + setting_order = settings.SEARCH_SOURCE.split(',') or [] + sort_order = {} + for index, source in enumerate(setting_order): + sort_order[source] = index + result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4)) + return result[(page - 1) * count:page * count] @router.get("/scrape", summary="刮削媒体信息", response_model=schemas.Response)