From 7faaaf3dcd7c3c6309b1f67ac79fbe7829ac0fd5 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Sat, 11 Nov 2023 14:14:09 +0800 Subject: [PATCH] fix bug --- app/api/endpoints/download.py | 9 ++++++++- app/api/endpoints/media.py | 18 ++++++++++-------- app/chain/media.py | 6 +++--- tests/test_recognize.py | 2 +- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/app/api/endpoints/download.py b/app/api/endpoints/download.py index b2267074..0bed5384 100644 --- a/app/api/endpoints/download.py +++ b/app/api/endpoints/download.py @@ -62,11 +62,18 @@ def exists(media_in: schemas.MediaInfo, # 媒体信息 meta = MetaInfo(title=media_in.title) mtype = MediaType(media_in.type) if media_in.type else None + if mtype: + meta.type = mtype + if media_in.season: + meta.begin_season = media_in.season + meta.type = MediaType.TV + if media_in.year: + meta.year = media_in.year if media_in.tmdb_id or media_in.douban_id: mediainfo = MediaChain().recognize_media(meta=meta, mtype=mtype, tmdbid=media_in.tmdb_id, doubanid=media_in.douban_id) else: - mediainfo = MediaChain().recognize_by_title(title=f"{media_in.title} {media_in.year}") + mediainfo = MediaChain().recognize_by_meta(metainfo=meta) # 查询缺失信息 if not mediainfo or not mediainfo.tmdb_id: raise HTTPException(status_code=404, detail="媒体信息不存在") diff --git a/app/api/endpoints/media.py b/app/api/endpoints/media.py index 682a485e..3185bf16 100644 --- a/app/api/endpoints/media.py +++ b/app/api/endpoints/media.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from app import schemas from app.chain.media import MediaChain from app.core.config import settings +from app.core.context import Context from app.core.metainfo import MetaInfo from app.core.security import verify_token, verify_uri_token from app.db import get_db @@ -23,9 +24,10 @@ def recognize(title: str, 根据标题、副标题识别媒体信息 """ # 识别媒体信息 - context = MediaChain().recognize_by_title(title=title, subtitle=subtitle) - if context: - return context.to_dict() + metainfo = MetaInfo(title, subtitle) + mediainfo = MediaChain().recognize_by_meta(metainfo) + if mediainfo: + return Context(meta_info=metainfo, media_info=mediainfo).to_dict() return schemas.Context() @@ -41,8 +43,8 @@ def recognize2(title: str, @router.get("/recognize_file", summary="识别媒体信息(文件)", response_model=schemas.Context) -def recognize(path: str, - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +def recognize_file(path: str, + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据文件路径识别媒体信息 """ @@ -54,13 +56,13 @@ def recognize(path: str, @router.get("/recognize_file2", summary="识别文件媒体信息(API_TOKEN)", response_model=schemas.Context) -def recognize2(path: str, - _: str = Depends(verify_uri_token)) -> Any: +def recognize_file2(path: str, + _: str = Depends(verify_uri_token)) -> Any: """ 根据文件路径识别媒体信息 API_TOKEN认证(?token=xxx) """ # 识别媒体信息 - return recognize(path) + return recognize_file(path) @router.get("/search", summary="搜索媒体信息", response_model=List[schemas.MediaInfo]) diff --git a/app/chain/media.py b/app/chain/media.py index 1d7c821f..649f5b36 100644 --- a/app/chain/media.py +++ b/app/chain/media.py @@ -26,13 +26,13 @@ class MediaChain(ChainBase, metaclass=Singleton): # 临时识别结果 {title, name, year, season, episode} recognize_temp: Optional[dict] = None - def recognize_by_title(self, title: str, subtitle: str = None) -> Optional[MediaInfo]: + def recognize_by_meta(self, metainfo: MetaBase) -> Optional[MediaInfo]: """ 根据主副标题识别媒体信息 """ + title = metainfo.title + subtitle = metainfo.subtitle logger.info(f'开始识别媒体信息,标题:{title},副标题:{subtitle} ...') - # 识别元数据 - metainfo = MetaInfo(title, subtitle) # 识别媒体信息 mediainfo: MediaInfo = self.recognize_media(meta=metainfo) if not mediainfo: diff --git a/tests/test_recognize.py b/tests/test_recognize.py index 1d7a9317..5b5e2473 100644 --- a/tests/test_recognize.py +++ b/tests/test_recognize.py @@ -15,7 +15,7 @@ class RecognizeTest(TestCase): pass def test_recognize(self): - media_info = MediaChain().recognize_by_title(title="我和我的祖国 2019") + media_info = MediaChain().recognize_by_meta(MetaInfo("我和我的祖国 2019")) self.assertEqual(media_info.tmdb_id, 612845) exists = DownloadChain().get_no_exists_info(MetaInfo("我和我的祖国 2019"), media_info) self.assertTrue(exists[0])