diff --git a/app/api/endpoints/dashboard.py b/app/api/endpoints/dashboard.py index 3c660adc..f5398642 100644 --- a/app/api/endpoints/dashboard.py +++ b/app/api/endpoints/dashboard.py @@ -19,11 +19,12 @@ router = APIRouter() @router.get("/statistic", summary="媒体数量统计", response_model=schemas.Statistic) -def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any: +def statistic(db: Session = Depends(get_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询媒体数量统计信息 """ - media_statistic = DashboardChain().media_statistic() + media_statistic = DashboardChain(db).media_statistic() if media_statistic: return schemas.Statistic( movie_count=media_statistic.movie_count, @@ -61,11 +62,12 @@ def processes(_: schemas.TokenPayload = Depends(verify_token)) -> Any: @router.get("/downloader", summary="下载器信息", response_model=schemas.DownloaderInfo) -def downloader(_: schemas.TokenPayload = Depends(verify_token)) -> Any: +def downloader(db: Session = Depends(get_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询下载器信息 """ - transfer_info = DashboardChain().downloader_info() + transfer_info = DashboardChain(db).downloader_info() free_space = SystemUtils.free_space(Path(settings.DOWNLOAD_PATH)) return schemas.DownloaderInfo( download_speed=transfer_info.download_speed, diff --git a/app/api/endpoints/douban.py b/app/api/endpoints/douban.py index 8d66cce8..a1e7c6bf 100644 --- a/app/api/endpoints/douban.py +++ b/app/api/endpoints/douban.py @@ -1,12 +1,14 @@ from typing import List, Any from fastapi import APIRouter, Depends, Response +from sqlalchemy.orm import Session from app import schemas from app.chain.douban import DoubanChain from app.core.config import settings from app.core.context import MediaInfo from app.core.security import verify_token +from app.db import get_db from app.schemas import MediaType from app.utils.http import RequestUtils @@ -30,12 +32,13 @@ def douban_img(imgurl: str) -> Any: @router.get("/recognize/{doubanid}", summary="豆瓣ID识别", response_model=schemas.Context) def recognize_doubanid(doubanid: str, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据豆瓣ID识别媒体信息 """ # 识别媒体信息 - context = DoubanChain().recognize_by_doubanid(doubanid=doubanid) + context = DoubanChain(db).recognize_by_doubanid(doubanid=doubanid) if context: return context.to_dict() else: @@ -47,12 +50,13 @@ def douban_movies(sort: str = "R", tags: str = "", page: int = 1, count: int = 30, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览豆瓣电影信息 """ - movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE, - sort=sort, tags=tags, page=page, count=count) + movies = DoubanChain(db).douban_discover(mtype=MediaType.MOVIE, + sort=sort, tags=tags, page=page, count=count) if not movies: return [] medias = [MediaInfo(douban_info=movie) for movie in movies] @@ -67,12 +71,13 @@ def douban_tvs(sort: str = "R", tags: str = "", page: int = 1, count: int = 30, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览豆瓣剧集信息 """ - tvs = DoubanChain().douban_discover(mtype=MediaType.TV, - sort=sort, tags=tags, page=page, count=count) + tvs = DoubanChain(db).douban_discover(mtype=MediaType.TV, + sort=sort, tags=tags, page=page, count=count) if not tvs: return [] medias = [MediaInfo(douban_info=tv) for tv in tvs] @@ -86,42 +91,47 @@ def douban_tvs(sort: str = "R", @router.get("/movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo]) def movie_top250(page: int = 1, count: int = 30, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览豆瓣剧集信息 """ - movies = DoubanChain().movie_top250(page=page, count=count) + movies = DoubanChain(db).movie_top250(page=page, count=count) return [MediaInfo(douban_info=movie).to_dict() for movie in movies] @router.get("/tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo]) def tv_weekly_chinese(page: int = 1, count: int = 30, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 中国每周剧集口碑榜 """ - tvs = DoubanChain().tv_weekly_chinese(page=page, count=count) + tvs = DoubanChain(db).tv_weekly_chinese(page=page, count=count) return [MediaInfo(douban_info=tv).to_dict() for tv in tvs] @router.get("/tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo]) def tv_weekly_global(page: int = 1, count: int = 30, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 全球每周剧集口碑榜 """ - tvs = DoubanChain().tv_weekly_global(page=page, count=count) + tvs = DoubanChain(db).tv_weekly_global(page=page, count=count) return [MediaInfo(douban_info=tv).to_dict() for tv in tvs] @router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo) -def douban_info(doubanid: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any: +def douban_info(doubanid: str, + db: Session = Depends(get_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据豆瓣ID查询豆瓣媒体信息 """ - doubaninfo = DoubanChain().douban_info(doubanid=doubanid) + doubaninfo = DoubanChain(db).douban_info(doubanid=doubanid) if doubaninfo: return MediaInfo(douban_info=doubaninfo).to_dict() else: diff --git a/app/api/endpoints/download.py b/app/api/endpoints/download.py index f8328c99..d8842951 100644 --- a/app/api/endpoints/download.py +++ b/app/api/endpoints/download.py @@ -1,6 +1,7 @@ from typing import Any, List from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session from app import schemas from app.chain.douban import DoubanChain @@ -9,6 +10,7 @@ from app.chain.media import MediaChain from app.core.context import MediaInfo, Context, TorrentInfo from app.core.metainfo import MetaInfo from app.core.security import verify_token +from app.db import get_db from app.db.models.user import User from app.db.userauth import get_current_active_superuser from app.schemas import NotExistMediaInfo, MediaType @@ -18,18 +20,20 @@ router = APIRouter() @router.get("/", summary="正在下载", response_model=List[schemas.DownloadingTorrent]) def read_downloading( + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询正在下载的任务 """ - return DownloadChain().downloading() + return DownloadChain(db).downloading() @router.post("/", summary="添加下载", response_model=schemas.Response) def add_downloading( media_in: schemas.MediaInfo, torrent_in: schemas.TorrentInfo, - current_user: User = Depends(get_current_active_superuser)) -> Any: + db: Session = Depends(get_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 添加下载任务 """ @@ -47,7 +51,7 @@ def add_downloading( media_info=mediainfo, torrent_info=torrentinfo ) - did = DownloadChain().download_single(context=context) + did = DownloadChain(db).download_single(context=context) return schemas.Response(success=True if did else False, data={ "download_id": did }) @@ -55,6 +59,7 @@ def add_downloading( @router.post("/notexists", summary="查询缺失媒体信息", response_model=List[NotExistMediaInfo]) def exists(media_in: schemas.MediaInfo, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询缺失媒体信息 @@ -65,19 +70,19 @@ def exists(media_in: schemas.MediaInfo, if media_in.tmdb_id: mediainfo.from_dict(media_in.dict()) elif media_in.douban_id: - context = DoubanChain().recognize_by_doubanid(doubanid=media_in.douban_id) + context = DoubanChain(db).recognize_by_doubanid(doubanid=media_in.douban_id) if context: mediainfo = context.media_info meta = context.meta_info else: - context = MediaChain().recognize_by_title(title=f"{media_in.title} {media_in.year}") + context = MediaChain(db).recognize_by_title(title=f"{media_in.title} {media_in.year}") if context: mediainfo = context.media_info meta = context.meta_info # 查询缺失信息 if not mediainfo or not mediainfo.tmdb_id: raise HTTPException(status_code=404, detail="媒体信息不存在") - exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo) + exist_flag, no_exists = DownloadChain(db).get_no_exists_info(meta=meta, mediainfo=mediainfo) if mediainfo.type == MediaType.MOVIE: # 电影已存在时返回空列表,存在时返回空对像列表 return [] if exist_flag else [NotExistMediaInfo()] @@ -90,31 +95,34 @@ def exists(media_in: schemas.MediaInfo, @router.put("/{hashString}/start", summary="开始任务", response_model=schemas.Response) def start_downloading( hashString: str, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 开如下载任务 """ - ret = DownloadChain().set_downloading(hashString, "start") + ret = DownloadChain(db).set_downloading(hashString, "start") return schemas.Response(success=True if ret else False) @router.put("/{hashString}/stop", summary="暂停任务", response_model=schemas.Response) def stop_downloading( hashString: str, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 控制下载任务 """ - ret = DownloadChain().set_downloading(hashString, "stop") + ret = DownloadChain(db).set_downloading(hashString, "stop") return schemas.Response(success=True if ret else False) @router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response) def remove_downloading( hashString: str, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 控制下载任务 """ - ret = DownloadChain().remove_downloading(hashString) + ret = DownloadChain(db).remove_downloading(hashString) return schemas.Response(success=True if ret else False) diff --git a/app/api/endpoints/history.py b/app/api/endpoints/history.py index f8ffa164..ab4396d8 100644 --- a/app/api/endpoints/history.py +++ b/app/api/endpoints/history.py @@ -74,7 +74,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory, if not history: return schemas.Response(success=False, msg="记录不存在") # 册除文件 - TransferChain().delete_files(Path(history.dest)) + TransferChain(db).delete_files(Path(history.dest)) # 删除记录 TransferHistory.delete(db, history_in.id) return schemas.Response(success=True) @@ -84,12 +84,13 @@ def delete_transfer_history(history_in: schemas.TransferHistory, def redo_transfer_history(history_in: schemas.TransferHistory, mtype: str, new_tmdbid: int, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 历史记录重新转移 """ hash_str = history_in.download_hash - result = TransferChain().process(f"{hash_str} {new_tmdbid}|{mtype}") + result = TransferChain(db).process(f"{hash_str} {new_tmdbid}|{mtype}") if result: return schemas.Response(success=True) else: diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py index 80cbe68f..a24d0438 100644 --- a/app/api/endpoints/login.py +++ b/app/api/endpoints/login.py @@ -36,7 +36,7 @@ async def login_access_token( if not user: # 请求协助认证 logger.warn("登录用户本地不匹配,尝试辅助认证 ...") - token = UserChain().user_authenticate(form_data.username, form_data.password) + token = UserChain(db).user_authenticate(form_data.username, form_data.password) if not token: raise HTTPException(status_code=401, detail="用户名或密码不正确") else: @@ -83,11 +83,11 @@ def bing_wallpaper() -> Any: @router.get("/tmdb", summary="TMDB电影海报", response_model=schemas.Response) -def tmdb_wallpaper() -> Any: +def tmdb_wallpaper(db: Session = Depends(get_db)) -> Any: """ 获取TMDB电影海报 """ - infos = TmdbChain().tmdb_trending() + infos = TmdbChain(db).tmdb_trending() if infos: # 随机一个电影 while True: diff --git a/app/api/endpoints/media.py b/app/api/endpoints/media.py index 07aec392..e523b6b9 100644 --- a/app/api/endpoints/media.py +++ b/app/api/endpoints/media.py @@ -20,12 +20,13 @@ router = APIRouter() @router.get("/recognize", summary="识别媒体信息", response_model=schemas.Context) def recognize(title: str, subtitle: str = None, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据标题、副标题识别媒体信息 """ # 识别媒体信息 - context = MediaChain().recognize_by_title(title=title, subtitle=subtitle) + context = MediaChain(db).recognize_by_title(title=title, subtitle=subtitle) if context: return context.to_dict() return schemas.Context() @@ -35,11 +36,12 @@ def recognize(title: str, def search_by_title(title: str, page: int = 1, count: int = 8, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 模糊搜索媒体信息列表 """ - _, medias = MediaChain().search(title=title) + _, medias = MediaChain(db).search(title=title) if medias: return [media.to_dict() for media in medias[(page - 1) * count: page * count]] return [] @@ -69,20 +71,21 @@ def exists(title: str = None, @router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo) def tmdb_info(mediaid: str, type_name: str, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据媒体ID查询themoviedb或豆瓣媒体信息,type_name: 电影/电视剧 """ mtype = MediaType(type_name) if mediaid.startswith("tmdb:"): - result = TmdbChain().tmdb_info(int(mediaid[5:]), mtype) + result = TmdbChain(db).tmdb_info(int(mediaid[5:]), mtype) return MediaInfo(tmdb_info=result).to_dict() elif mediaid.startswith("douban:"): # 查询豆瓣信息 - doubaninfo = DoubanChain().douban_info(doubanid=mediaid[7:]) + doubaninfo = DoubanChain(db).douban_info(doubanid=mediaid[7:]) if not doubaninfo: return schemas.MediaInfo() - result = DoubanChain().recognize_by_doubaninfo(doubaninfo) + result = DoubanChain(db).recognize_by_doubaninfo(doubaninfo) if result: # TMDB return result.media_info.to_dict() diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index bdf683bb..9886ce76 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -19,22 +19,23 @@ from app.schemas.types import SystemConfigKey, NotificationType router = APIRouter() -def start_message_chain(body: Any, form: Any, args: Any): +def start_message_chain(db: Session, body: Any, form: Any, args: Any): """ 启动链式任务 """ - MessageChain().process(body=body, form=form, args=args) + MessageChain(db).process(body=body, form=form, args=args) @router.post("/", summary="接收用户消息", response_model=schemas.Response) -async def user_message(background_tasks: BackgroundTasks, request: Request): +async def user_message(background_tasks: BackgroundTasks, request: Request, + db: Session = Depends(get_db)): """ 用户消息响应 """ body = await request.body() form = await request.form() args = request.query_params - background_tasks.add_task(start_message_chain, body, form, args) + background_tasks.add_task(start_message_chain, db, body, form, args) return schemas.Response(success=True) diff --git a/app/api/endpoints/rss.py b/app/api/endpoints/rss.py index b7a7191c..10674527 100644 --- a/app/api/endpoints/rss.py +++ b/app/api/endpoints/rss.py @@ -15,11 +15,11 @@ from app.schemas import MediaType router = APIRouter() -def start_rss_refresh(rssid: int = None): +def start_rss_refresh(db: Session, rssid: int = None): """ 启动自定义订阅刷新 """ - RssChain().refresh(rssid=rssid, manual=True) + RssChain(db).refresh(rssid=rssid, manual=True) @router.get("/", summary="所有自定义订阅", response_model=List[schemas.Rss]) @@ -36,6 +36,7 @@ def read_rsses( def create_rss( *, rss_in: schemas.Rss, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token) ) -> Any: """ @@ -45,7 +46,7 @@ def create_rss( mtype = MediaType(rss_in.type) else: mtype = None - rssid, errormsg = RssChain().add( + rssid, errormsg = RssChain(db).add( mtype=mtype, **rss_in.dict() ) @@ -100,11 +101,13 @@ def preview_rss( def refresh_rss( rssid: int, background_tasks: BackgroundTasks, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据ID刷新自定义订阅 """ background_tasks.add_task(start_rss_refresh, + db=db, rssid=rssid) return schemas.Response(success=True) diff --git a/app/api/endpoints/search.py b/app/api/endpoints/search.py index da7081e7..a65850ad 100644 --- a/app/api/endpoints/search.py +++ b/app/api/endpoints/search.py @@ -1,28 +1,32 @@ from typing import List, Any from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session from app import schemas from app.chain.douban import DoubanChain from app.chain.search import SearchChain from app.core.security import verify_token +from app.db import get_db from app.schemas.types import MediaType router = APIRouter() @router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context]) -async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def search_latest(db: Session = Depends(get_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询搜索结果 """ - torrents = SearchChain().last_search_results() + torrents = SearchChain(db).last_search_results() return [torrent.to_dict() for torrent in torrents] @router.get("/media/{mediaid}", summary="精确搜索资源", response_model=List[schemas.Context]) def search_by_tmdbid(mediaid: str, mtype: str = None, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/ @@ -31,15 +35,15 @@ def search_by_tmdbid(mediaid: str, tmdbid = int(mediaid.replace("tmdb:", "")) if mtype: mtype = MediaType(mtype) - torrents = SearchChain().search_by_tmdbid(tmdbid=tmdbid, mtype=mtype) + torrents = SearchChain(db).search_by_tmdbid(tmdbid=tmdbid, mtype=mtype) elif mediaid.startswith("douban:"): doubanid = mediaid.replace("douban:", "") # 识别豆瓣信息 - context = DoubanChain().recognize_by_doubanid(doubanid) + context = DoubanChain(db).recognize_by_doubanid(doubanid) if not context or not context.media_info or not context.media_info.tmdb_id: raise HTTPException(status_code=404, detail="无法识别TMDB媒体信息!") - torrents = SearchChain().search_by_tmdbid(tmdbid=context.media_info.tmdb_id, - mtype=context.media_info.type) + torrents = SearchChain(db).search_by_tmdbid(tmdbid=context.media_info.tmdb_id, + mtype=context.media_info.type) else: return [] return [torrent.to_dict() for torrent in torrents] @@ -49,9 +53,10 @@ def search_by_tmdbid(mediaid: str, async def search_by_title(keyword: str = None, page: int = 0, site: int = None, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源 """ - torrents = SearchChain().search_by_title(title=keyword, page=page, site=site) + torrents = SearchChain(db).search_by_title(title=keyword, page=page, site=site) return [torrent.to_dict() for torrent in torrents] diff --git a/app/api/endpoints/site.py b/app/api/endpoints/site.py index d2d6ae52..53cac473 100644 --- a/app/api/endpoints/site.py +++ b/app/api/endpoints/site.py @@ -19,11 +19,11 @@ from app.utils.string import StringUtils router = APIRouter() -def start_cookiecloud_sync(): +def start_cookiecloud_sync(db: Session): """ 后台启动CookieCloud站点同步 """ - CookieCloudChain().process(manual=True) + CookieCloudChain(db).process(manual=True) @router.get("/", summary="所有站点", response_model=List[schemas.Site]) @@ -67,11 +67,12 @@ def delete_site( @router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response) def cookie_cloud_sync(background_tasks: BackgroundTasks, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 运行CookieCloud同步站点信息 """ - background_tasks.add_task(start_cookiecloud_sync) + background_tasks.add_task(start_cookiecloud_sync, db) return schemas.Response(success=True, message="CookieCloud同步任务已启动!") @@ -83,7 +84,7 @@ def cookie_cloud_sync(db: Session = Depends(get_db), """ Site.reset(db) SystemConfigOper(db).set(SystemConfigKey.IndexerSites, []) - CookieCloudChain().process(manual=True) + CookieCloudChain(db).process(manual=True) return schemas.Response(success=True, message="站点已重置!") @@ -105,9 +106,9 @@ def update_cookie( detail=f"站点 {site_id} 不存在!", ) # 更新Cookie - state, message = SiteChain().update_cookie(site_info=site_info, - username=username, - password=password) + state, message = SiteChain(db).update_cookie(site_info=site_info, + username=username, + password=password) return schemas.Response(success=state, message=message) @@ -124,7 +125,7 @@ def test_site(site_id: int, status_code=404, detail=f"站点 {site_id} 不存在", ) - status, message = SiteChain().test(site.domain) + status, message = SiteChain(db).test(site.domain) return schemas.Response(success=status, message=message) @@ -162,7 +163,7 @@ def site_resource(site_id: int, keyword: str = None, status_code=404, detail=f"站点 {site_id} 不存在", ) - torrents = SearchChain().browse(site.domain, keyword) + torrents = SearchChain(db).browse(site.domain, keyword) if not torrents: return [] return [torrent.to_dict() for torrent in torrents] diff --git a/app/api/endpoints/subscribe.py b/app/api/endpoints/subscribe.py index ed43fb69..38bf4c0c 100644 --- a/app/api/endpoints/subscribe.py +++ b/app/api/endpoints/subscribe.py @@ -17,20 +17,20 @@ from app.schemas.types import MediaType router = APIRouter() -def start_subscribe_add(title: str, year: str, +def start_subscribe_add(db: Session, title: str, year: str, mtype: MediaType, tmdbid: int, season: int, username: str): """ 启动订阅任务 """ - SubscribeChain().add(title=title, year=year, - mtype=mtype, tmdbid=tmdbid, season=season, username=username) + SubscribeChain(db).add(title=title, year=year, + mtype=mtype, tmdbid=tmdbid, season=season, username=username) -def start_subscribe_search(sid: Optional[int], state: Optional[str]): +def start_subscribe_search(db: Session, sid: Optional[int], state: Optional[str]): """ 启动订阅搜索任务 """ - SubscribeChain().search(sid=sid, state=state, manual=True) + SubscribeChain(db).search(sid=sid, state=state, manual=True) @router.get("/", summary="所有订阅", response_model=List[schemas.Subscribe]) @@ -51,6 +51,7 @@ def read_subscribes( def create_subscribe( *, subscribe_in: schemas.Subscribe, + db: Session = Depends(get_db), current_user: User = Depends(get_current_active_user), ) -> Any: """ @@ -66,15 +67,15 @@ def create_subscribe( title = subscribe_in.name else: title = None - sid, message = SubscribeChain().add(mtype=mtype, - title=title, - year=subscribe_in.year, - tmdbid=subscribe_in.tmdbid, - season=subscribe_in.season, - doubanid=subscribe_in.doubanid, - username=current_user.name, - best_version=subscribe_in.best_version, - exist_ok=True) + sid, message = SubscribeChain(db).add(mtype=mtype, + title=title, + year=subscribe_in.year, + tmdbid=subscribe_in.tmdbid, + season=subscribe_in.season, + doubanid=subscribe_in.doubanid, + username=current_user.name, + best_version=subscribe_in.best_version, + exist_ok=True) return schemas.Response(success=True if sid else False, message=message, data={ "id": sid }) @@ -171,6 +172,7 @@ def delete_subscribe( @router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response) async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, + db: Session = Depends(get_db), authorization: str = Header(None)) -> Any: """ Jellyseerr/Overseerr订阅 @@ -198,6 +200,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, # 添加订阅 if media_type == MediaType.MOVIE: background_tasks.add_task(start_subscribe_add, + db=db, mtype=media_type, tmdbid=tmdbId, title=subject, @@ -212,6 +215,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, break for season in seasons: background_tasks.add_task(start_subscribe_add, + db=db, mtype=media_type, tmdbid=tmdbId, title=subject, @@ -224,11 +228,12 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, @router.get("/refresh", summary="刷新订阅", response_model=schemas.Response) def refresh_subscribes( + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 刷新所有订阅 """ - SubscribeChain().refresh() + SubscribeChain(db).refresh() return schemas.Response(success=True) @@ -236,20 +241,22 @@ def refresh_subscribes( def search_subscribe( subscribe_id: int, background_tasks: BackgroundTasks, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 搜索所有订阅 """ - background_tasks.add_task(start_subscribe_search, sid=subscribe_id, state=None) + background_tasks.add_task(start_subscribe_search, db=db, sid=subscribe_id, state=None) return schemas.Response(success=True) @router.get("/search", summary="搜索所有订阅", response_model=schemas.Response) def search_subscribes( background_tasks: BackgroundTasks, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 搜索所有订阅 """ - background_tasks.add_task(start_subscribe_search, sid=None, state='R') + background_tasks.add_task(start_subscribe_search, db=db, sid=None, state='R') return schemas.Response(success=True) diff --git a/app/api/endpoints/tmdb.py b/app/api/endpoints/tmdb.py index 06f3acac..8419eda1 100644 --- a/app/api/endpoints/tmdb.py +++ b/app/api/endpoints/tmdb.py @@ -1,22 +1,25 @@ from typing import List, Any from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session from app import schemas from app.chain.tmdb import TmdbChain from app.core.context import MediaInfo from app.core.security import verify_token +from app.db import get_db from app.schemas.types import MediaType router = APIRouter() @router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason]) -def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any: +def tmdb_seasons(tmdbid: int, db: Session = Depends(get_db), + _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询themoviedb所有季信息 """ - seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid) + seasons_info = TmdbChain(db).tmdb_seasons(tmdbid=tmdbid) if not seasons_info: return [] else: @@ -26,15 +29,16 @@ def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) - @router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo]) def tmdb_similar(tmdbid: int, type_name: str, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询类似电影/电视剧,type_name: 电影/电视剧 """ mediatype = MediaType(type_name) if mediatype == MediaType.MOVIE: - tmdbinfos = TmdbChain().movie_similar(tmdbid=tmdbid) + tmdbinfos = TmdbChain(db).movie_similar(tmdbid=tmdbid) elif mediatype == MediaType.TV: - tmdbinfos = TmdbChain().tv_similar(tmdbid=tmdbid) + tmdbinfos = TmdbChain(db).tv_similar(tmdbid=tmdbid) else: return [] if not tmdbinfos: @@ -46,15 +50,16 @@ def tmdb_similar(tmdbid: int, @router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo]) def tmdb_recommend(tmdbid: int, type_name: str, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询推荐电影/电视剧,type_name: 电影/电视剧 """ mediatype = MediaType(type_name) if mediatype == MediaType.MOVIE: - tmdbinfos = TmdbChain().movie_recommend(tmdbid=tmdbid) + tmdbinfos = TmdbChain(db).movie_recommend(tmdbid=tmdbid) elif mediatype == MediaType.TV: - tmdbinfos = TmdbChain().tv_recommend(tmdbid=tmdbid) + tmdbinfos = TmdbChain(db).tv_recommend(tmdbid=tmdbid) else: return [] if not tmdbinfos: @@ -67,15 +72,16 @@ def tmdb_recommend(tmdbid: int, def tmdb_credits(tmdbid: int, type_name: str, page: int = 1, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询演员阵容,type_name: 电影/电视剧 """ mediatype = MediaType(type_name) if mediatype == MediaType.MOVIE: - tmdbinfos = TmdbChain().movie_credits(tmdbid=tmdbid, page=page) + tmdbinfos = TmdbChain(db).movie_credits(tmdbid=tmdbid, page=page) elif mediatype == MediaType.TV: - tmdbinfos = TmdbChain().tv_credits(tmdbid=tmdbid, page=page) + tmdbinfos = TmdbChain(db).tv_credits(tmdbid=tmdbid, page=page) else: return [] if not tmdbinfos: @@ -86,11 +92,12 @@ def tmdb_credits(tmdbid: int, @router.get("/person/{person_id}", summary="人物详情", response_model=schemas.TmdbPerson) def tmdb_person(person_id: int, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据人物ID查询人物详情 """ - tmdbinfo = TmdbChain().person_detail(person_id=person_id) + tmdbinfo = TmdbChain(db).person_detail(person_id=person_id) if not tmdbinfo: return schemas.TmdbPerson() else: @@ -100,11 +107,12 @@ def tmdb_person(person_id: int, @router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo]) def tmdb_person_credits(person_id: int, page: int = 1, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据人物ID查询人物参演作品 """ - tmdbinfo = TmdbChain().person_credits(person_id=person_id, page=page) + tmdbinfo = TmdbChain(db).person_credits(person_id=person_id, page=page) if not tmdbinfo: return [] else: @@ -116,15 +124,16 @@ def tmdb_movies(sort_by: str = "popularity.desc", with_genres: str = "", with_original_language: str = "", page: int = 1, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览TMDB电影信息 """ - movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE, - sort_by=sort_by, - with_genres=with_genres, - with_original_language=with_original_language, - page=page) + movies = TmdbChain(db).tmdb_discover(mtype=MediaType.MOVIE, + sort_by=sort_by, + with_genres=with_genres, + with_original_language=with_original_language, + page=page) if not movies: return [] return [MediaInfo(tmdb_info=movie).to_dict() for movie in movies] @@ -135,15 +144,16 @@ def tmdb_tvs(sort_by: str = "popularity.desc", with_genres: str = "", with_original_language: str = "", page: int = 1, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览TMDB剧集信息 """ - tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV, - sort_by=sort_by, - with_genres=with_genres, - with_original_language=with_original_language, - page=page) + tvs = TmdbChain(db).tmdb_discover(mtype=MediaType.TV, + sort_by=sort_by, + with_genres=with_genres, + with_original_language=with_original_language, + page=page) if not tvs: return [] return [MediaInfo(tmdb_info=tv).to_dict() for tv in tvs] @@ -151,11 +161,12 @@ def tmdb_tvs(sort_by: str = "popularity.desc", @router.get("/trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo]) def tmdb_trending(page: int = 1, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 浏览TMDB剧集信息 """ - infos = TmdbChain().tmdb_trending(page=page) + infos = TmdbChain(db).tmdb_trending(page=page) if not infos: return [] return [MediaInfo(tmdb_info=info).to_dict() for info in infos] @@ -163,11 +174,12 @@ def tmdb_trending(page: int = 1, @router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode]) def tmdb_season_episodes(tmdbid: int, season: int, + db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据TMDBID查询某季的所有信信息 """ - episodes_info = TmdbChain().tmdb_episodes(tmdbid=tmdbid, season=season) + episodes_info = TmdbChain(db).tmdb_episodes(tmdbid=tmdbid, season=season) if not episodes_info: return [] else: diff --git a/app/api/endpoints/webhook.py b/app/api/endpoints/webhook.py index a5379aba..cd6daf42 100644 --- a/app/api/endpoints/webhook.py +++ b/app/api/endpoints/webhook.py @@ -1,24 +1,27 @@ from typing import Any -from fastapi import APIRouter, BackgroundTasks, Request +from fastapi import APIRouter, BackgroundTasks, Request, Depends +from sqlalchemy.orm import Session from app import schemas from app.chain.webhook import WebhookChain from app.core.config import settings +from app.db import get_db router = APIRouter() -def start_webhook_chain(body: Any, form: Any, args: Any): +def start_webhook_chain(db: Session, body: Any, form: Any, args: Any): """ 启动链式任务 """ - WebhookChain().message(body=body, form=form, args=args) + WebhookChain(db).message(body=body, form=form, args=args) @router.post("/", summary="Webhook消息响应", response_model=schemas.Response) async def webhook_message(background_tasks: BackgroundTasks, - token: str, request: Request) -> Any: + token: str, request: Request, + db: Session = Depends(get_db),) -> Any: """ Webhook响应 """ @@ -27,18 +30,19 @@ async def webhook_message(background_tasks: BackgroundTasks, body = await request.body() form = await request.form() args = request.query_params - background_tasks.add_task(start_webhook_chain, body, form, args) + background_tasks.add_task(start_webhook_chain, db, body, form, args) return schemas.Response(success=True) @router.get("/", summary="Webhook消息响应", response_model=schemas.Response) async def webhook_message(background_tasks: BackgroundTasks, - token: str, request: Request) -> Any: + token: str, request: Request, + db: Session = Depends(get_db)) -> Any: """ Webhook响应 """ if token != settings.API_TOKEN: return schemas.Response(success=False, message="token认证不通过") args = request.query_params - background_tasks.add_task(start_webhook_chain, None, None, args) + background_tasks.add_task(start_webhook_chain, db, None, None, args) return schemas.Response(success=True) diff --git a/app/api/servarr.py b/app/api/servarr.py index 1c651be5..0b819377 100644 --- a/app/api/servarr.py +++ b/app/api/servarr.py @@ -235,11 +235,11 @@ def arr_movie_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> A ) tmdbid = term.replace("tmdb:", "") # 查询媒体信息 - mediainfo = MediaChain().recognize_media(mtype=MediaType.MOVIE, tmdbid=int(tmdbid)) + mediainfo = MediaChain(db).recognize_media(mtype=MediaType.MOVIE, tmdbid=int(tmdbid)) if not mediainfo: return [RadarrMovie()] # 查询是否已存在 - exists = MediaChain().media_exists(mediainfo=mediainfo) + exists = MediaChain(db).media_exists(mediainfo=mediainfo) if not exists: # 文件不存在 hasfile = False @@ -324,11 +324,11 @@ def arr_add_movie(apikey: str, "id": subscribe.id } # 添加订阅 - sid, message = SubscribeChain().add(title=movie.title, - year=movie.year, - mtype=MediaType.MOVIE, - tmdbid=movie.tmdbId, - userid="Seerr") + sid, message = SubscribeChain(db).add(title=movie.title, + year=movie.year, + mtype=MediaType.MOVIE, + tmdbid=movie.tmdbId, + userid="Seerr") if sid: return { "id": sid @@ -515,8 +515,8 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> # 获取TVDBID if not term.startswith("tvdb:"): - mediainfo = MediaChain().recognize_media(meta=MetaInfo(term), - mtype=MediaType.TV) + mediainfo = MediaChain(db).recognize_media(meta=MetaInfo(term), + mtype=MediaType.TV) if not mediainfo: return [SonarrSeries()] tvdbid = mediainfo.tvdb_id @@ -527,7 +527,7 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> tvdbid = int(term.replace("tvdb:", "")) # 查询TVDB信息 - tvdbinfo = MediaChain().tvdb_info(tvdbid=tvdbid) + tvdbinfo = MediaChain(db).tvdb_info(tvdbid=tvdbid) if not tvdbinfo: return [SonarrSeries()] @@ -539,11 +539,11 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> # 根据TVDB查询媒体信息 if not mediainfo: - mediainfo = MediaChain().recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')), - mtype=MediaType.TV) + mediainfo = MediaChain(db).recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')), + mtype=MediaType.TV) # 查询是否存在 - exists = MediaChain().media_exists(mediainfo) + exists = MediaChain(db).media_exists(mediainfo) if exists: hasfile = True else: @@ -666,12 +666,12 @@ def arr_add_series(apikey: str, tv: schemas.SonarrSeries, for season in left_seasons: if not season.get("monitored"): continue - sid, message = SubscribeChain().add(title=tv.title, - year=tv.year, - season=season.get("seasonNumber"), - tmdbid=tv.tmdbId, - mtype=MediaType.TV, - userid="Seerr") + sid, message = SubscribeChain(db).add(title=tv.title, + year=tv.year, + season=season.get("seasonNumber"), + tmdbid=tv.tmdbId, + mtype=MediaType.TV, + userid="Seerr") if sid: return { diff --git a/app/chain/__init__.py b/app/chain/__init__.py index 9e14230a..659a5979 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -7,6 +7,7 @@ from typing import Optional, Any, Tuple, List, Set, Union, Dict from qbittorrentapi import TorrentFilesList from ruamel.yaml import CommentedMap +from sqlalchemy.orm import Session from transmission_rpc import File from app.core.config import settings @@ -27,10 +28,11 @@ class ChainBase(metaclass=ABCMeta): 处理链基类 """ - def __init__(self): + def __init__(self, db: Session = None): """ 公共初始化 """ + self._db = db self.modulemanager = ModuleManager() self.eventmanager = EventManager() diff --git a/app/chain/cookiecloud.py b/app/chain/cookiecloud.py index 8fc7f7fe..704f46c1 100644 --- a/app/chain/cookiecloud.py +++ b/app/chain/cookiecloud.py @@ -3,6 +3,7 @@ from typing import Tuple, Optional, Union from urllib.parse import urljoin from lxml import etree +from sqlalchemy.orm import Session from app.chain import ChainBase from app.chain.site import SiteChain @@ -22,12 +23,12 @@ class CookieCloudChain(ChainBase): CookieCloud处理链 """ - def __init__(self): - super().__init__() - self.siteoper = SiteOper() - self.siteiconoper = SiteIconOper() + def __init__(self, db: Session = None): + super().__init__(db) + self.siteoper = SiteOper(self._db) + self.siteiconoper = SiteIconOper(self._db) self.siteshelper = SitesHelper() - self.sitechain = SiteChain() + self.sitechain = SiteChain(self._db) self.message = MessageHelper() self.cookiecloud = CookieCloudHelper( server=settings.COOKIECLOUD_HOST, diff --git a/app/chain/download.py b/app/chain/download.py index 417998c1..ecd27ca9 100644 --- a/app/chain/download.py +++ b/app/chain/download.py @@ -2,6 +2,8 @@ import re from pathlib import Path from typing import List, Optional, Tuple, Set, Dict, Union +from sqlalchemy.orm import Session + from app.chain import ChainBase from app.core.config import settings from app.core.context import MediaInfo, TorrentInfo, Context @@ -20,11 +22,11 @@ class DownloadChain(ChainBase): 下载处理链 """ - def __init__(self): - super().__init__() + def __init__(self, db: Session = None): + super().__init__(db) self.torrent = TorrentHelper() - self.downloadhis = DownloadHistoryOper() - self.mediaserver = MediaServerOper() + self.downloadhis = DownloadHistoryOper(self._db) + self.mediaserver = MediaServerOper(self._db) def post_download_message(self, meta: MetaBase, mediainfo: MediaInfo, torrent: TorrentInfo, channel: MessageChannel = None, diff --git a/app/chain/mediaserver.py b/app/chain/mediaserver.py index fe6bf00e..2df894e8 100644 --- a/app/chain/mediaserver.py +++ b/app/chain/mediaserver.py @@ -2,6 +2,8 @@ import json import threading from typing import List, Union, Generator +from sqlalchemy.orm import Session + from app import schemas from app.chain import ChainBase from app.core.config import settings @@ -17,9 +19,9 @@ class MediaServerChain(ChainBase): 媒体服务器处理链 """ - def __init__(self): - super().__init__() - self.mediaserverdb = MediaServerOper() + def __init__(self, db: Session = None): + super().__init__(db) + self.mediaserverdb = MediaServerOper(db) def librarys(self) -> List[schemas.MediaServerLibrary]: """ diff --git a/app/chain/message.py b/app/chain/message.py index 836ebe1a..07d85faa 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -27,12 +27,12 @@ class MessageChain(ChainBase): # 每页数据量 _page_size: int = 8 - def __init__(self): - super().__init__() - self.downloadchain = DownloadChain() - self.subscribechain = SubscribeChain() - self.searchchain = SearchChain() - self.medtachain = MediaChain() + def __init__(self, db: Session = None): + super().__init__(db) + self.downloadchain = DownloadChain(self._db) + self.subscribechain = SubscribeChain(self._db) + self.searchchain = SearchChain(self._db) + self.medtachain = MediaChain(self._db) self.torrent = TorrentHelper() self.eventmanager = EventManager() self.torrenthelper = TorrentHelper() diff --git a/app/chain/rss.py b/app/chain/rss.py index a8470e8e..45ab96d9 100644 --- a/app/chain/rss.py +++ b/app/chain/rss.py @@ -4,6 +4,8 @@ import time from datetime import datetime from typing import Tuple, Optional +from sqlalchemy.orm import Session + from app.chain import ChainBase from app.chain.download import DownloadChain from app.core.config import settings @@ -25,12 +27,12 @@ class RssChain(ChainBase): RSS处理链 """ - def __init__(self): - super().__init__() - self.rssoper = RssOper() + def __init__(self, db: Session = None): + super().__init__(db) + self.rssoper = RssOper(self._db) self.sites = SitesHelper() - self.systemconfig = SystemConfigOper() - self.downloadchain = DownloadChain() + self.systemconfig = SystemConfigOper(self._db) + self.downloadchain = DownloadChain(self._db) self.message = MessageHelper() def add(self, title: str, year: str, diff --git a/app/chain/search.py b/app/chain/search.py index 0a49072f..b1f7c249 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -4,6 +4,8 @@ from datetime import datetime from typing import Dict from typing import List, Optional +from sqlalchemy.orm import Session + from app.chain import ChainBase from app.core.context import Context from app.core.context import MediaInfo, TorrentInfo @@ -23,11 +25,11 @@ class SearchChain(ChainBase): 站点资源搜索处理链 """ - def __init__(self): - super().__init__() + def __init__(self, db: Session = None): + super().__init__(db) self.siteshelper = SitesHelper() self.progress = ProgressHelper() - self.systemconfig = SystemConfigOper() + self.systemconfig = SystemConfigOper(self._db) self.torrenthelper = TorrentHelper() def search_by_tmdbid(self, tmdbid: int, mtype: MediaType = None) -> List[Context]: diff --git a/app/chain/site.py b/app/chain/site.py index 1fcdc37a..fac91384 100644 --- a/app/chain/site.py +++ b/app/chain/site.py @@ -1,5 +1,7 @@ from typing import Union, Tuple +from sqlalchemy.orm import Session + from app.chain import ChainBase from app.core.config import settings from app.db.models.site import Site @@ -20,9 +22,9 @@ class SiteChain(ChainBase): 站点管理处理链 """ - def __init__(self): - super().__init__() - self.siteoper = SiteOper() + def __init__(self, db: Session = None): + super().__init__(db) + self.siteoper = SiteOper(self._db) self.cookiehelper = CookieHelper() self.message = MessageHelper() diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 05058526..3a1bc893 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -3,6 +3,8 @@ import re from datetime import datetime from typing import Dict, List, Optional, Union, Tuple +from requests import Session + from app.chain import ChainBase from app.chain.download import DownloadChain from app.chain.search import SearchChain @@ -27,14 +29,14 @@ class SubscribeChain(ChainBase): _cache_file = "__torrents_cache__" - def __init__(self): - super().__init__() - self.downloadchain = DownloadChain() - self.searchchain = SearchChain() - self.subscribehelper = SubscribeOper() + def __init__(self, db: Session = None): + super().__init__(db) + self.downloadchain = DownloadChain(self._db) + self.searchchain = SearchChain(self._db) + self.subscribehelper = SubscribeOper(self._db) self.siteshelper = SitesHelper() self.message = MessageHelper() - self.systemconfig = SystemConfigOper() + self.systemconfig = SystemConfigOper(self._db) def add(self, title: str, year: str, mtype: MediaType = None, diff --git a/app/chain/transfer.py b/app/chain/transfer.py index 29f50a03..ff2e5d3d 100644 --- a/app/chain/transfer.py +++ b/app/chain/transfer.py @@ -5,6 +5,8 @@ import threading from pathlib import Path from typing import List, Optional, Union +from sqlalchemy.orm import Session + from app.chain import ChainBase from app.core.config import settings from app.core.context import MediaInfo @@ -28,10 +30,10 @@ class TransferChain(ChainBase): 文件转移处理链 """ - def __init__(self): - super().__init__() - self.downloadhis = DownloadHistoryOper() - self.transferhis = TransferHistoryOper() + def __init__(self, db: Session = None): + super().__init__(db) + self.downloadhis = DownloadHistoryOper(self._db) + self.transferhis = TransferHistoryOper(self._db) self.progress = ProgressHelper() def process(self, arg_str: str = None, channel: MessageChannel = None, userid: Union[str, int] = None) -> bool: diff --git a/app/db/__init__.py b/app/db/__init__.py index 44b0c613..e1367cde 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -33,8 +33,11 @@ class DbOper: _db: Session = None - def __init__(self, _db=SessionLocal()): - self._db = _db + def __init__(self, db: Session = None): + if db: + self._db = db + else: + self._db = SessionLocal() def __del__(self): if self._db: diff --git a/app/db/downloadhistory_oper.py b/app/db/downloadhistory_oper.py index da76f7c8..ee1e6e80 100644 --- a/app/db/downloadhistory_oper.py +++ b/app/db/downloadhistory_oper.py @@ -15,7 +15,7 @@ class DownloadHistoryOper(DbOper): 按路径查询下载记录 :param path: 数据key """ - return DownloadHistory.get_by_path(self._db, path) + return DownloadHistory.get_by_path(self._db, str(path)) def get_by_hash(self, download_hash: str) -> Any: """ diff --git a/app/db/mediaserver_oper.py b/app/db/mediaserver_oper.py index b6b5e305..6eb8764c 100644 --- a/app/db/mediaserver_oper.py +++ b/app/db/mediaserver_oper.py @@ -1,7 +1,9 @@ import json from typing import Optional -from app.db import DbOper, SessionLocal +from sqlalchemy.orm import Session + +from app.db import DbOper from app.db.models.mediaserver import MediaServerItem @@ -10,7 +12,7 @@ class MediaServerOper(DbOper): 媒体服务器数据管理 """ - def __init__(self, db=SessionLocal()): + def __init__(self, db: Session = None): super().__init__(db) def add(self, **kwargs) -> bool: diff --git a/app/db/rss_oper.py b/app/db/rss_oper.py index 88abe070..2e649496 100644 --- a/app/db/rss_oper.py +++ b/app/db/rss_oper.py @@ -1,6 +1,8 @@ from typing import List -from app.db import DbOper, SessionLocal +from sqlalchemy.orm import Session + +from app.db import DbOper from app.db.models.rss import Rss @@ -9,7 +11,7 @@ class RssOper(DbOper): RSS订阅数据管理 """ - def __init__(self, db=SessionLocal()): + def __init__(self, db: Session = None): super().__init__(db) def add(self, **kwargs) -> bool: diff --git a/app/db/systemconfig_oper.py b/app/db/systemconfig_oper.py index 5cafdfdb..12ccdd3c 100644 --- a/app/db/systemconfig_oper.py +++ b/app/db/systemconfig_oper.py @@ -1,22 +1,24 @@ import json from typing import Any, Union -from app.db import DbOper, SessionLocal +from sqlalchemy.orm import Session + +from app.db import DbOper from app.db.models.systemconfig import SystemConfig +from app.schemas.types import SystemConfigKey from app.utils.object import ObjectUtils from app.utils.singleton import Singleton -from app.schemas.types import SystemConfigKey class SystemConfigOper(DbOper, metaclass=Singleton): # 配置对象 __SYSTEMCONF: dict = {} - def __init__(self, _db=SessionLocal()): + def __init__(self, db: Session = None): """ 加载配置到内存 """ - super().__init__(_db) + super().__init__(db) for item in SystemConfig.list(self._db): if ObjectUtils.is_obj(item.value): self.__SYSTEMCONF[item.key] = json.loads(item.value)