diff --git a/app/api/endpoints/douban.py b/app/api/endpoints/douban.py index 9459e91a..6128e5bf 100644 --- a/app/api/endpoints/douban.py +++ b/app/api/endpoints/douban.py @@ -1,14 +1,12 @@ from typing import List, Any from fastapi import APIRouter, Depends, Response -from sqlalchemy.orm import Session from app import schemas from app.chain.douban import DoubanChain from app.core.config import settings from app.core.context import MediaInfo from app.core.security import verify_token -from app.db import get_db from app.schemas import MediaType from app.utils.http import RequestUtils @@ -32,13 +30,12 @@ def douban_img(imgurl: str) -> Any: @router.get("/recognize/{doubanid}", summary="豆瓣ID识别", response_model=schemas.Context) def recognize_doubanid(doubanid: str, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据豆瓣ID识别媒体信息 """ # 识别媒体信息 - context = DoubanChain(db).recognize_by_doubanid(doubanid=doubanid) + context = DoubanChain().recognize_by_doubanid(doubanid=doubanid) if context: return context.to_dict() else: @@ -48,12 +45,11 @@ def recognize_doubanid(doubanid: str, @router.get("/showing", summary="豆瓣正在热映", response_model=List[schemas.MediaInfo]) def movie_showing(page: int = 1, count: int = 30, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览豆瓣正在热映 """ - movies = DoubanChain(db).movie_showing(page=page, count=count) + movies = DoubanChain().movie_showing(page=page, count=count) if not movies: return [] medias = [MediaInfo(douban_info=movie) for movie in movies] @@ -65,13 +61,12 @@ def douban_movies(sort: str = "R", tags: str = "", page: int = 1, count: int = 30, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览豆瓣电影信息 """ - movies = DoubanChain(db).douban_discover(mtype=MediaType.MOVIE, - sort=sort, tags=tags, page=page, count=count) + movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE, + sort=sort, tags=tags, page=page, count=count) if not movies: return [] medias = [MediaInfo(douban_info=movie) for movie in movies] @@ -86,13 +81,12 @@ def douban_tvs(sort: str = "R", tags: str = "", page: int = 1, count: int = 30, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览豆瓣剧集信息 """ - tvs = DoubanChain(db).douban_discover(mtype=MediaType.TV, - sort=sort, tags=tags, page=page, count=count) + tvs = DoubanChain().douban_discover(mtype=MediaType.TV, + sort=sort, tags=tags, page=page, count=count) if not tvs: return [] medias = [MediaInfo(douban_info=tv) for tv in tvs] @@ -106,59 +100,54 @@ def douban_tvs(sort: str = "R", @router.get("/movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo]) def movie_top250(page: int = 1, count: int = 30, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览豆瓣剧集信息 """ - movies = DoubanChain(db).movie_top250(page=page, count=count) + movies = DoubanChain().movie_top250(page=page, count=count) return [MediaInfo(douban_info=movie).to_dict() for movie in movies] @router.get("/tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo]) def tv_weekly_chinese(page: int = 1, count: int = 30, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 中国每周剧集口碑榜 """ - tvs = DoubanChain(db).tv_weekly_chinese(page=page, count=count) + tvs = DoubanChain().tv_weekly_chinese(page=page, count=count) return [MediaInfo(douban_info=tv).to_dict() for tv in tvs] @router.get("/tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo]) def tv_weekly_global(page: int = 1, count: int = 30, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 全球每周剧集口碑榜 """ - tvs = DoubanChain(db).tv_weekly_global(page=page, count=count) + tvs = DoubanChain().tv_weekly_global(page=page, count=count) return [MediaInfo(douban_info=tv).to_dict() for tv in tvs] @router.get("/tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo]) def tv_animation(page: int = 1, count: int = 30, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 热门动画剧集 """ - tvs = DoubanChain(db).tv_animation(page=page, count=count) + tvs = DoubanChain().tv_animation(page=page, count=count) return [MediaInfo(douban_info=tv).to_dict() for tv in tvs] @router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo) def douban_info(doubanid: str, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据豆瓣ID查询豆瓣媒体信息 """ - doubaninfo = DoubanChain(db).douban_info(doubanid=doubanid) + doubaninfo = DoubanChain().douban_info(doubanid=doubanid) if doubaninfo: return MediaInfo(douban_info=doubaninfo).to_dict() else: diff --git a/app/api/endpoints/download.py b/app/api/endpoints/download.py index df164686..9750c00b 100644 --- a/app/api/endpoints/download.py +++ b/app/api/endpoints/download.py @@ -68,12 +68,12 @@ def exists(media_in: schemas.MediaInfo, if media_in.tmdb_id: mediainfo.from_dict(media_in.dict()) elif media_in.douban_id: - context = DoubanChain(db).recognize_by_doubanid(doubanid=media_in.douban_id) + context = DoubanChain().recognize_by_doubanid(doubanid=media_in.douban_id) if context: mediainfo = context.media_info meta = context.meta_info else: - context = MediaChain(db).recognize_by_title(title=f"{media_in.title} {media_in.year}") + context = MediaChain().recognize_by_title(title=f"{media_in.title} {media_in.year}") if context: mediainfo = context.media_info meta = context.meta_info diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py index a421994a..019fbbb5 100644 --- a/app/api/endpoints/login.py +++ b/app/api/endpoints/login.py @@ -74,11 +74,11 @@ def bing_wallpaper() -> Any: @router.get("/tmdb", summary="TMDB电影海报", response_model=schemas.Response) -def tmdb_wallpaper(db: Session = Depends(get_db)) -> Any: +def tmdb_wallpaper() -> Any: """ 获取TMDB电影海报 """ - wallpager = TmdbChain(db).get_random_wallpager() + wallpager = TmdbChain().get_random_wallpager() if wallpager: return schemas.Response( success=True, diff --git a/app/api/endpoints/media.py b/app/api/endpoints/media.py index 2e49f2ab..f2cd8a9e 100644 --- a/app/api/endpoints/media.py +++ b/app/api/endpoints/media.py @@ -20,13 +20,12 @@ router = APIRouter() @router.get("/recognize", summary="识别媒体信息(种子)", response_model=schemas.Context) def recognize(title: str, subtitle: str = None, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据标题、副标题识别媒体信息 """ # 识别媒体信息 - context = MediaChain(db).recognize_by_title(title=title, subtitle=subtitle) + context = MediaChain().recognize_by_title(title=title, subtitle=subtitle) if context: return context.to_dict() return schemas.Context() @@ -34,13 +33,12 @@ def recognize(title: str, @router.get("/recognize_file", summary="识别媒体信息(文件)", response_model=schemas.Context) def recognize(path: str, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据文件路径识别媒体信息 """ # 识别媒体信息 - context = MediaChain(db).recognize_by_path(path) + context = MediaChain().recognize_by_path(path) if context: return context.to_dict() return schemas.Context() @@ -50,12 +48,11 @@ def recognize(path: str, def search_by_title(title: str, page: int = 1, count: int = 8, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 模糊搜索媒体信息列表 """ - _, medias = MediaChain(db).search(title=title) + _, medias = MediaChain().search(title=title) if medias: return [media.to_dict() for media in medias[(page - 1) * count: page * count]] return [] @@ -85,21 +82,20 @@ def exists(title: str = None, @router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo) def tmdb_info(mediaid: str, type_name: str, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据媒体ID查询themoviedb或豆瓣媒体信息,type_name: 电影/电视剧 """ mtype = MediaType(type_name) if mediaid.startswith("tmdb:"): - result = TmdbChain(db).tmdb_info(int(mediaid[5:]), mtype) + result = TmdbChain().tmdb_info(int(mediaid[5:]), mtype) return MediaInfo(tmdb_info=result).to_dict() elif mediaid.startswith("douban:"): # 查询豆瓣信息 - doubaninfo = DoubanChain(db).douban_info(doubanid=mediaid[7:]) + doubaninfo = DoubanChain().douban_info(doubanid=mediaid[7:]) if not doubaninfo: return schemas.MediaInfo() - result = DoubanChain(db).recognize_by_doubaninfo(doubaninfo) + result = DoubanChain().recognize_by_doubaninfo(doubaninfo) if result: # TMDB return result.media_info.to_dict() diff --git a/app/api/endpoints/search.py b/app/api/endpoints/search.py index 406032ea..35212238 100644 --- a/app/api/endpoints/search.py +++ b/app/api/endpoints/search.py @@ -40,7 +40,7 @@ def search_by_tmdbid(mediaid: str, elif mediaid.startswith("douban:"): doubanid = mediaid.replace("douban:", "") # 识别豆瓣信息 - context = DoubanChain(db).recognize_by_doubanid(doubanid) + context = DoubanChain().recognize_by_doubanid(doubanid) if not context or not context.media_info or not context.media_info.tmdb_id: return [] torrents = SearchChain(db).search_by_tmdbid(tmdbid=context.media_info.tmdb_id, diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index 6cd2d9ed..6cea9665 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -223,7 +223,7 @@ def execute_command(jobid: str, if not jobid: return schemas.Response(success=False, message="命令不能为空!") if jobid == "subscribe_search": - Scheduler().start(jobid, state = 'R') + Scheduler().start(jobid, state='R') else: Scheduler().start(jobid) - return schemas.Response(success=True) \ No newline at end of file + return schemas.Response(success=True) diff --git a/app/api/endpoints/tmdb.py b/app/api/endpoints/tmdb.py index 8419eda1..06f3acac 100644 --- a/app/api/endpoints/tmdb.py +++ b/app/api/endpoints/tmdb.py @@ -1,25 +1,22 @@ from typing import List, Any from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session from app import schemas from app.chain.tmdb import TmdbChain from app.core.context import MediaInfo from app.core.security import verify_token -from app.db import get_db from app.schemas.types import MediaType router = APIRouter() @router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason]) -def tmdb_seasons(tmdbid: int, db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询themoviedb所有季信息 """ - seasons_info = TmdbChain(db).tmdb_seasons(tmdbid=tmdbid) + seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid) if not seasons_info: return [] else: @@ -29,16 +26,15 @@ def tmdb_seasons(tmdbid: int, db: Session = Depends(get_db), @router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo]) def tmdb_similar(tmdbid: int, type_name: str, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询类似电影/电视剧,type_name: 电影/电视剧 """ mediatype = MediaType(type_name) if mediatype == MediaType.MOVIE: - tmdbinfos = TmdbChain(db).movie_similar(tmdbid=tmdbid) + tmdbinfos = TmdbChain().movie_similar(tmdbid=tmdbid) elif mediatype == MediaType.TV: - tmdbinfos = TmdbChain(db).tv_similar(tmdbid=tmdbid) + tmdbinfos = TmdbChain().tv_similar(tmdbid=tmdbid) else: return [] if not tmdbinfos: @@ -50,16 +46,15 @@ def tmdb_similar(tmdbid: int, @router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo]) def tmdb_recommend(tmdbid: int, type_name: str, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询推荐电影/电视剧,type_name: 电影/电视剧 """ mediatype = MediaType(type_name) if mediatype == MediaType.MOVIE: - tmdbinfos = TmdbChain(db).movie_recommend(tmdbid=tmdbid) + tmdbinfos = TmdbChain().movie_recommend(tmdbid=tmdbid) elif mediatype == MediaType.TV: - tmdbinfos = TmdbChain(db).tv_recommend(tmdbid=tmdbid) + tmdbinfos = TmdbChain().tv_recommend(tmdbid=tmdbid) else: return [] if not tmdbinfos: @@ -72,16 +67,15 @@ def tmdb_recommend(tmdbid: int, def tmdb_credits(tmdbid: int, type_name: str, page: int = 1, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询演员阵容,type_name: 电影/电视剧 """ mediatype = MediaType(type_name) if mediatype == MediaType.MOVIE: - tmdbinfos = TmdbChain(db).movie_credits(tmdbid=tmdbid, page=page) + tmdbinfos = TmdbChain().movie_credits(tmdbid=tmdbid, page=page) elif mediatype == MediaType.TV: - tmdbinfos = TmdbChain(db).tv_credits(tmdbid=tmdbid, page=page) + tmdbinfos = TmdbChain().tv_credits(tmdbid=tmdbid, page=page) else: return [] if not tmdbinfos: @@ -92,12 +86,11 @@ def tmdb_credits(tmdbid: int, @router.get("/person/{person_id}", summary="人物详情", response_model=schemas.TmdbPerson) def tmdb_person(person_id: int, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据人物ID查询人物详情 """ - tmdbinfo = TmdbChain(db).person_detail(person_id=person_id) + tmdbinfo = TmdbChain().person_detail(person_id=person_id) if not tmdbinfo: return schemas.TmdbPerson() else: @@ -107,12 +100,11 @@ def tmdb_person(person_id: int, @router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo]) def tmdb_person_credits(person_id: int, page: int = 1, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据人物ID查询人物参演作品 """ - tmdbinfo = TmdbChain(db).person_credits(person_id=person_id, page=page) + tmdbinfo = TmdbChain().person_credits(person_id=person_id, page=page) if not tmdbinfo: return [] else: @@ -124,16 +116,15 @@ def tmdb_movies(sort_by: str = "popularity.desc", with_genres: str = "", with_original_language: str = "", page: int = 1, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览TMDB电影信息 """ - movies = TmdbChain(db).tmdb_discover(mtype=MediaType.MOVIE, - sort_by=sort_by, - with_genres=with_genres, - with_original_language=with_original_language, - page=page) + movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE, + sort_by=sort_by, + with_genres=with_genres, + with_original_language=with_original_language, + page=page) if not movies: return [] return [MediaInfo(tmdb_info=movie).to_dict() for movie in movies] @@ -144,16 +135,15 @@ def tmdb_tvs(sort_by: str = "popularity.desc", with_genres: str = "", with_original_language: str = "", page: int = 1, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览TMDB剧集信息 """ - tvs = TmdbChain(db).tmdb_discover(mtype=MediaType.TV, - sort_by=sort_by, - with_genres=with_genres, - with_original_language=with_original_language, - page=page) + tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV, + sort_by=sort_by, + with_genres=with_genres, + with_original_language=with_original_language, + page=page) if not tvs: return [] return [MediaInfo(tmdb_info=tv).to_dict() for tv in tvs] @@ -161,12 +151,11 @@ def tmdb_tvs(sort_by: str = "popularity.desc", @router.get("/trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo]) def tmdb_trending(page: int = 1, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览TMDB剧集信息 """ - infos = TmdbChain(db).tmdb_trending(page=page) + infos = TmdbChain().tmdb_trending(page=page) if not infos: return [] return [MediaInfo(tmdb_info=info).to_dict() for info in infos] @@ -174,12 +163,11 @@ def tmdb_trending(page: int = 1, @router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode]) def tmdb_season_episodes(tmdbid: int, season: int, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询某季的所有信信息 """ - episodes_info = TmdbChain(db).tmdb_episodes(tmdbid=tmdbid, season=season) + episodes_info = TmdbChain().tmdb_episodes(tmdbid=tmdbid, season=season) if not episodes_info: return [] else: diff --git a/app/api/servarr.py b/app/api/servarr.py index 52b7914e..90df22ca 100644 --- a/app/api/servarr.py +++ b/app/api/servarr.py @@ -301,11 +301,11 @@ def arr_movie_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> A ) tmdbid = term.replace("tmdb:", "") # 查询媒体信息 - mediainfo = MediaChain(db).recognize_media(mtype=MediaType.MOVIE, tmdbid=int(tmdbid)) + mediainfo = MediaChain().recognize_media(mtype=MediaType.MOVIE, tmdbid=int(tmdbid)) if not mediainfo: return [RadarrMovie()] # 查询是否已存在 - exists = MediaChain(db).media_exists(mediainfo=mediainfo) + exists = MediaChain().media_exists(mediainfo=mediainfo) if not exists: # 文件不存在 hasfile = False @@ -581,7 +581,7 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> # 获取TVDBID if not term.startswith("tvdb:"): - mediainfo = MediaChain(db).recognize_media(meta=MetaInfo(term), + mediainfo = MediaChain().recognize_media(meta=MetaInfo(term), mtype=MediaType.TV) if not mediainfo: return [SonarrSeries()] @@ -593,7 +593,7 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> tvdbid = int(term.replace("tvdb:", "")) # 查询TVDB信息 - tvdbinfo = MediaChain(db).tvdb_info(tvdbid=tvdbid) + tvdbinfo = MediaChain().tvdb_info(tvdbid=tvdbid) if not tvdbinfo: return [SonarrSeries()] @@ -605,11 +605,11 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> # 根据TVDB查询媒体信息 if not mediainfo: - mediainfo = MediaChain(db).recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')), + mediainfo = MediaChain().recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')), mtype=MediaType.TV) # 查询是否存在 - exists = MediaChain(db).media_exists(mediainfo) + exists = MediaChain().media_exists(mediainfo) if exists: hasfile = True else: diff --git a/app/chain/douban.py b/app/chain/douban.py index 0875a76c..5605f089 100644 --- a/app/chain/douban.py +++ b/app/chain/douban.py @@ -6,11 +6,12 @@ from app.core.context import MediaInfo from app.core.metainfo import MetaInfo from app.log import logger from app.schemas import MediaType +from app.utils.singleton import Singleton -class DoubanChain(ChainBase): +class DoubanChain(ChainBase, metaclass=Singleton): """ - 豆瓣处理链 + 豆瓣处理链,单例运行 """ def recognize_by_doubanid(self, doubanid: str) -> Optional[Context]: diff --git a/app/chain/media.py b/app/chain/media.py index e797525f..8e8acb01 100644 --- a/app/chain/media.py +++ b/app/chain/media.py @@ -1,18 +1,31 @@ +import copy +import time from pathlib import Path +from threading import Lock from typing import Optional, List, Tuple from app.chain import ChainBase from app.core.context import Context, MediaInfo +from app.core.event import eventmanager, Event from app.core.meta import MetaBase from app.core.metainfo import MetaInfo, MetaInfoPath from app.log import logger +from app.schemas.types import EventType, MediaType +from app.utils.singleton import Singleton from app.utils.string import StringUtils -class MediaChain(ChainBase): +recognize_lock = Lock() + + +class MediaChain(ChainBase, metaclass=Singleton): """ - 媒体信息处理链 + 媒体信息处理链,单例运行 """ + # 临时识别标题 + recognize_title: Optional[str] = None + # 临时识别结果 {title, name, year, season, episode} + recognize_temp: Optional[dict] = None def recognize_by_title(self, title: str, subtitle: str = None) -> Optional[Context]: """ @@ -24,14 +37,104 @@ class MediaChain(ChainBase): # 识别媒体信息 mediainfo: MediaInfo = self.recognize_media(meta=metainfo) if not mediainfo: - logger.warn(f'{title} 未识别到媒体信息') - return Context(meta_info=metainfo) + # 偿试使用辅助识别,如果有注册响应事件的话 + if eventmanager.check(EventType.NameRecognize): + logger.info(f'请求辅助识别,标题:{title} ...') + mediainfo = self.recognize_help(title=title, org_meta=metainfo) + if not mediainfo: + logger.warn(f'{title} 未识别到媒体信息') + return Context(meta_info=metainfo) + # 识别成功 logger.info(f'{title} 识别到媒体信息:{mediainfo.type.value} {mediainfo.title_year}') # 更新媒体图片 self.obtain_images(mediainfo=mediainfo) # 返回上下文 return Context(meta_info=metainfo, media_info=mediainfo) + def recognize_help(self, title: str, org_meta: MetaBase) -> Optional[MediaInfo]: + """ + 请求辅助识别,返回媒体信息 + :param title: 标题 + :param org_meta: 原始元数据 + """ + with recognize_lock: + self.recognize_temp = None + self.recognize_title = title + + # 发送请求事件 + eventmanager.send_event( + EventType.NameRecognize, + { + 'title': title, + } + ) + # 每0.5秒循环一次,等待结果,直到10秒后超时 + for i in range(10): + if self.recognize_temp is not None: + break + time.sleep(0.5) + # 加锁 + with recognize_lock: + mediainfo = None + if not self.recognize_temp or self.recognize_title != title: + # 没有识别结果或者识别标题已改变 + return None + # 有识别结果 + meta_dict = copy.deepcopy(self.recognize_temp) + logger.info(f'获取到辅助识别结果:{meta_dict}') + if meta_dict.get("name") == org_meta.name and meta_dict.get("year") == org_meta.year: + logger.info(f'辅助识别结果与原始识别结果一致') + else: + logger.info(f'辅助识别结果与原始识别结果不一致,重新匹配媒体信息 ...') + org_meta.name = meta_dict.get("name") + org_meta.year = meta_dict.get("year") + org_meta.begin_season = meta_dict.get("season") + org_meta.begin_episode = meta_dict.get("episode") + if org_meta.begin_season or org_meta.begin_episode: + org_meta.type = MediaType.TV + # 重新识别 + mediainfo = self.recognize_media(meta=org_meta) + return mediainfo + + @eventmanager.register(EventType.NameRecognizeResult) + def recognize_result(self, event: Event): + """ + 监控识别结果事件,获取辅助识别结果,结果格式:{title, name, year, season, episode} + """ + if not event: + return + event_data = event.event_data or {} + # 加锁 + with recognize_lock: + # 不是原标题的结果不要 + if event_data.get("title") != self.recognize_title: + return + # 标志收到返回 + self.recognize_temp = {} + # 处理数据格式 + file_title, file_year, season_number, episode_number = None, None, None, None + if event_data.get("name"): + file_title = str(event_data["name"]).split("/")[0].strip().replace(".", " ") + if event_data.get("year"): + file_year = str(event_data["year"]).split("/")[0].strip() + if event_data.get("season") and str(event_data["season"]).isdigit(): + season_number = int(event_data["season"]) + if event_data.get("episode") and str(event_data["episode"]).isdigit(): + episode_number = int(event_data["episode"]) + if not file_title: + return + if file_title == 'Unknown': + return + if not str(file_year).isdigit(): + file_year = None + # 结果赋值 + self.recognize_temp = { + "name": file_title, + "year": file_year, + "season": season_number, + "episode": episode_number + } + def recognize_by_path(self, path: str) -> Optional[Context]: """ 根据文件路径识别媒体信息 @@ -43,8 +146,13 @@ class MediaChain(ChainBase): # 识别媒体信息 mediainfo = self.recognize_media(meta=file_meta) if not mediainfo: - logger.warn(f'{path} 未识别到媒体信息') - return Context(meta_info=file_meta) + # 偿试使用辅助识别,如果有注册响应事件的话 + if eventmanager.check(EventType.NameRecognize): + logger.info(f'请求辅助识别,标题:{file_path.name} ...') + mediainfo = self.recognize_help(title=path, org_meta=file_meta) + if not mediainfo: + logger.warn(f'{path} 未识别到媒体信息') + return Context(meta_info=file_meta) logger.info(f'{path} 识别到媒体信息:{mediainfo.type.value} {mediainfo.title_year}') # 更新媒体图片 self.obtain_images(mediainfo=mediainfo) diff --git a/app/chain/message.py b/app/chain/message.py index a4cad49e..411f875c 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -32,7 +32,7 @@ class MessageChain(ChainBase): self.downloadchain = DownloadChain(self._db) self.subscribechain = SubscribeChain(self._db) self.searchchain = SearchChain(self._db) - self.medtachain = MediaChain(self._db) + self.medtachain = MediaChain() self.torrent = TorrentHelper() self.eventmanager = EventManager() self.torrenthelper = TorrentHelper() diff --git a/app/chain/tmdb.py b/app/chain/tmdb.py index a798fa53..0d49c2b1 100644 --- a/app/chain/tmdb.py +++ b/app/chain/tmdb.py @@ -12,7 +12,7 @@ from app.utils.singleton import Singleton class TmdbChain(ChainBase, metaclass=Singleton): """ - TheMovieDB处理链 + TheMovieDB处理链,单例运行 """ def tmdb_discover(self, mtype: MediaType, sort_by: str, with_genres: str, diff --git a/app/chain/transfer.py b/app/chain/transfer.py index cf6dc963..c217c96e 100644 --- a/app/chain/transfer.py +++ b/app/chain/transfer.py @@ -41,8 +41,8 @@ class TransferChain(ChainBase): self.downloadhis = DownloadHistoryOper(self._db) self.transferhis = TransferHistoryOper(self._db) self.progress = ProgressHelper() - self.mediachain = MediaChain(self._db) - self.tmdbchain = TmdbChain(self._db) + self.mediachain = MediaChain() + self.tmdbchain = TmdbChain() self.systemconfig = SystemConfigOper() def process(self) -> bool: diff --git a/app/command.py b/app/command.py index 4e4be4dc..dc903873 100644 --- a/app/command.py +++ b/app/command.py @@ -1,3 +1,4 @@ +import importlib import traceback from threading import Thread, Event from typing import Any, Union, Dict @@ -175,10 +176,24 @@ class Command(metaclass=Singleton): for handler in handlers: try: names = handler.__qualname__.split(".") - if names[0] == "Command": - self.command_event(event) + [class_name, method_name] = names + if class_name in self.pluginmanager.get_plugin_ids(): + # 插件事件 + self.pluginmanager.run_plugin_method(class_name, method_name, event) else: - self.pluginmanager.run_plugin_method(names[0], names[1], event) + # 检查全局变量中是否存在 + if class_name not in globals(): + # 导入模块,除了插件和Command本身,只有chain能响应事件 + module = importlib.import_module( + f"app.chain.{class_name[:-5].lower()}" + ) + class_obj = getattr(module, class_name)() + else: + # 通过类名创建类实例 + class_obj = globals()[class_name]() + # 检查类是否存在并调用方法 + if hasattr(class_obj, method_name): + getattr(class_obj, method_name)(event) except Exception as e: logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}") diff --git a/app/core/event.py b/app/core/event.py index 79e63078..da09bd0d 100644 --- a/app/core/event.py +++ b/app/core/event.py @@ -32,6 +32,12 @@ class EventManager(metaclass=Singleton): except Empty: return None, [] + def check(self, etype: EventType): + """ + 检查事件是否存在响应 + """ + return etype.value in self._handlers + def add_event_listener(self, etype: EventType, handler: type): """ 注册事件处理 diff --git a/app/core/meta/metabase.py b/app/core/meta/metabase.py index 296ae2bb..116b0203 100644 --- a/app/core/meta/metabase.py +++ b/app/core/meta/metabase.py @@ -87,6 +87,17 @@ class MetaBase(object): return self.cn_name return "" + @name.setter + def name(self, name: str): + """ + 设置名称 + """ + if StringUtils.is_all_chinese(name): + self.cn_name = name + else: + self.en_name = name + self.cn_name = None + def init_subtitle(self, title_text: str): """ 副标题识别 diff --git a/app/core/plugin.py b/app/core/plugin.py index 0ff14907..3bbbe0fb 100644 --- a/app/core/plugin.py +++ b/app/core/plugin.py @@ -177,6 +177,12 @@ class PluginManager(metaclass=Singleton): return None return getattr(self._running_plugins[pid], method)(*args, **kwargs) + def get_plugin_ids(self) -> List[str]: + """ + 获取所有插件ID + """ + return list(self._plugins.keys()) + def get_plugin_apps(self) -> List[dict]: """ 获取所有插件信息 diff --git a/app/plugins/chatgpt/__init__.py b/app/plugins/chatgpt/__init__.py index 6db0438a..c0316b2d 100644 --- a/app/plugins/chatgpt/__init__.py +++ b/app/plugins/chatgpt/__init__.py @@ -1,7 +1,8 @@ from typing import Any, List, Dict, Tuple from app.core.config import settings -from app.core.event import eventmanager +from app.core.event import eventmanager, Event +from app.log import logger from app.plugins import _PluginBase from app.plugins.chatgpt.openai import OpenAi from app.schemas.types import EventType @@ -33,6 +34,7 @@ class ChatGPT(_PluginBase): openai = None _enabled = False _proxy = False + _recognize = False _openai_url = None _openai_key = None @@ -40,6 +42,7 @@ class ChatGPT(_PluginBase): if config: self._enabled = config.get("enabled") self._proxy = config.get("proxy") + self._recognize = config.get("recognize") self._openai_url = config.get("openai_url") self._openai_key = config.get("openai_key") self.openai = OpenAi(api_key=self._openai_key, api_url=self._openai_url, @@ -70,7 +73,7 @@ class ChatGPT(_PluginBase): 'component': 'VCol', 'props': { 'cols': 12, - 'md': 6 + 'md': 4 }, 'content': [ { @@ -86,7 +89,7 @@ class ChatGPT(_PluginBase): 'component': 'VCol', 'props': { 'cols': 12, - 'md': 6 + 'md': 4 }, 'content': [ { @@ -97,6 +100,22 @@ class ChatGPT(_PluginBase): } } ] + }, + { + 'component': 'VCol', + 'props': { + 'cols': 12, + 'md': 4 + }, + 'content': [ + { + 'component': 'VSwitch', + 'props': { + 'model': 'recognize', + 'label': '辅助识别', + } + } + ] } ] }, @@ -143,6 +162,7 @@ class ChatGPT(_PluginBase): ], { "enabled": False, "proxy": False, + "recognize": False, "openai_url": "https://api.openai.com", "openai_key": "" } @@ -151,10 +171,12 @@ class ChatGPT(_PluginBase): pass @eventmanager.register(EventType.UserMessage) - def talk(self, event): + def talk(self, event: Event): """ 监听用户消息,获取ChatGPT回复 """ + if not self._enabled: + return if not self.openai: return text = event.event_data.get("text") @@ -166,6 +188,34 @@ class ChatGPT(_PluginBase): if response: self.post_message(channel=channel, title=response, userid=userid) + @eventmanager.register(EventType.NameRecognize) + def recognize(self, event: Event): + """ + 监听识别事件,使用ChatGPT辅助识别名称 + """ + if not self._enabled: + return + if not self.openai: + return + if not event.event_data: + return + title = event.event_data.get("title") + if not title: + return + response = self.openai.get_media_name(filename=title) + logger.info(f"ChatGPT辅助识别结果:{response}") + if response: + eventmanager.send_event( + EventType.NameRecognizeResult, + { + 'title': title, + 'name': response.get("title"), + 'year': response.get("year"), + 'season': response.get("season"), + 'episode': response.get("episode") + } + ) + def stop_service(self): """ 退出插件 diff --git a/app/plugins/dirmonitor/__init__.py b/app/plugins/dirmonitor/__init__.py index 23765b20..6718e9b4 100644 --- a/app/plugins/dirmonitor/__init__.py +++ b/app/plugins/dirmonitor/__init__.py @@ -97,7 +97,7 @@ class DirMonitor(_PluginBase): self.transferhis = TransferHistoryOper(self.db) self.downloadhis = DownloadHistoryOper(self.db) self.transferchian = TransferChain(self.db) - self.tmdbchain = TmdbChain(self.db) + self.tmdbchain = TmdbChain() # 清空配置 self._dirconf = {} self._transferconf = {} diff --git a/app/plugins/downloadingmsg/__init__.py b/app/plugins/downloadingmsg/__init__.py index 20419e37..73323b2b 100644 --- a/app/plugins/downloadingmsg/__init__.py +++ b/app/plugins/downloadingmsg/__init__.py @@ -156,7 +156,7 @@ class DownloadingMsg(_PluginBase): channel_value = downloadhis.channel else: try: - context = MediaChain(self.db).recognize_by_title(title=torrent.title) + context = MediaChain().recognize_by_title(title=torrent.title) if not context or not context.media_info: continue media_info = context.media_info diff --git a/app/plugins/personmeta/__init__.py b/app/plugins/personmeta/__init__.py index c4185ec0..3be06683 100644 --- a/app/plugins/personmeta/__init__.py +++ b/app/plugins/personmeta/__init__.py @@ -67,7 +67,7 @@ class PersonMeta(_PluginBase): _remove_nozh = False def init_plugin(self, config: dict = None): - self.tmdbchain = TmdbChain(self.db) + self.tmdbchain = TmdbChain() self.mschain = MediaServerChain(self.db) if config: self._enabled = config.get("enabled") diff --git a/app/schemas/types.py b/app/schemas/types.py index bcac684d..bd23fa93 100644 --- a/app/schemas/types.py +++ b/app/schemas/types.py @@ -40,6 +40,10 @@ class EventType(Enum): UserMessage = "user.message" # 通知消息 NoticeMessage = "notice.message" + # 名称识别请求 + NameRecognize = "name.recognize" + # 名称识别结果 + NameRecognizeResult = "name.recognize.result" # 系统配置Key字典