fix 数据库会话管理

This commit is contained in:
jxxghp 2023-08-16 10:22:45 +08:00
parent b086bbf015
commit da93328d50
29 changed files with 255 additions and 172 deletions

View File

@ -19,11 +19,12 @@ router = APIRouter()
@router.get("/statistic", summary="媒体数量统计", response_model=schemas.Statistic)
def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def statistic(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询媒体数量统计信息
"""
media_statistic = DashboardChain().media_statistic()
media_statistic = DashboardChain(db).media_statistic()
if media_statistic:
return schemas.Statistic(
movie_count=media_statistic.movie_count,
@ -61,11 +62,12 @@ def processes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/downloader", summary="下载器信息", response_model=schemas.DownloaderInfo)
def downloader(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def downloader(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询下载器信息
"""
transfer_info = DashboardChain().downloader_info()
transfer_info = DashboardChain(db).downloader_info()
free_space = SystemUtils.free_space(Path(settings.DOWNLOAD_PATH))
return schemas.DownloaderInfo(
download_speed=transfer_info.download_speed,

View File

@ -1,12 +1,14 @@
from typing import List, Any
from fastapi import APIRouter, Depends, Response
from sqlalchemy.orm import Session
from app import schemas
from app.chain.douban import DoubanChain
from app.core.config import settings
from app.core.context import MediaInfo
from app.core.security import verify_token
from app.db import get_db
from app.schemas import MediaType
from app.utils.http import RequestUtils
@ -30,12 +32,13 @@ def douban_img(imgurl: str) -> Any:
@router.get("/recognize/{doubanid}", summary="豆瓣ID识别", response_model=schemas.Context)
def recognize_doubanid(doubanid: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据豆瓣ID识别媒体信息
"""
# 识别媒体信息
context = DoubanChain().recognize_by_doubanid(doubanid=doubanid)
context = DoubanChain(db).recognize_by_doubanid(doubanid=doubanid)
if context:
return context.to_dict()
else:
@ -47,11 +50,12 @@ def douban_movies(sort: str = "R",
tags: str = "",
page: int = 1,
count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣电影信息
"""
movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE,
movies = DoubanChain(db).douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
if not movies:
return []
@ -67,11 +71,12 @@ def douban_tvs(sort: str = "R",
tags: str = "",
page: int = 1,
count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
tvs = DoubanChain().douban_discover(mtype=MediaType.TV,
tvs = DoubanChain(db).douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
if not tvs:
return []
@ -86,42 +91,47 @@ def douban_tvs(sort: str = "R",
@router.get("/movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
def movie_top250(page: int = 1,
count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
movies = DoubanChain().movie_top250(page=page, count=count)
movies = DoubanChain(db).movie_top250(page=page, count=count)
return [MediaInfo(douban_info=movie).to_dict() for movie in movies]
@router.get("/tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
def tv_weekly_chinese(page: int = 1,
count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
中国每周剧集口碑榜
"""
tvs = DoubanChain().tv_weekly_chinese(page=page, count=count)
tvs = DoubanChain(db).tv_weekly_chinese(page=page, count=count)
return [MediaInfo(douban_info=tv).to_dict() for tv in tvs]
@router.get("/tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
def tv_weekly_global(page: int = 1,
count: int = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
全球每周剧集口碑榜
"""
tvs = DoubanChain().tv_weekly_global(page=page, count=count)
tvs = DoubanChain(db).tv_weekly_global(page=page, count=count)
return [MediaInfo(douban_info=tv).to_dict() for tv in tvs]
@router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo)
def douban_info(doubanid: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
def douban_info(doubanid: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据豆瓣ID查询豆瓣媒体信息
"""
doubaninfo = DoubanChain().douban_info(doubanid=doubanid)
doubaninfo = DoubanChain(db).douban_info(doubanid=doubanid)
if doubaninfo:
return MediaInfo(douban_info=doubaninfo).to_dict()
else:

View File

@ -1,6 +1,7 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app import schemas
from app.chain.douban import DoubanChain
@ -9,6 +10,7 @@ from app.chain.media import MediaChain
from app.core.context import MediaInfo, Context, TorrentInfo
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.db import get_db
from app.db.models.user import User
from app.db.userauth import get_current_active_superuser
from app.schemas import NotExistMediaInfo, MediaType
@ -18,18 +20,20 @@ router = APIRouter()
@router.get("/", summary="正在下载", response_model=List[schemas.DownloadingTorrent])
def read_downloading(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询正在下载的任务
"""
return DownloadChain().downloading()
return DownloadChain(db).downloading()
@router.post("/", summary="添加下载", response_model=schemas.Response)
def add_downloading(
media_in: schemas.MediaInfo,
torrent_in: schemas.TorrentInfo,
current_user: User = Depends(get_current_active_superuser)) -> Any:
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
添加下载任务
"""
@ -47,7 +51,7 @@ def add_downloading(
media_info=mediainfo,
torrent_info=torrentinfo
)
did = DownloadChain().download_single(context=context)
did = DownloadChain(db).download_single(context=context)
return schemas.Response(success=True if did else False, data={
"download_id": did
})
@ -55,6 +59,7 @@ def add_downloading(
@router.post("/notexists", summary="查询缺失媒体信息", response_model=List[NotExistMediaInfo])
def exists(media_in: schemas.MediaInfo,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询缺失媒体信息
@ -65,19 +70,19 @@ def exists(media_in: schemas.MediaInfo,
if media_in.tmdb_id:
mediainfo.from_dict(media_in.dict())
elif media_in.douban_id:
context = DoubanChain().recognize_by_doubanid(doubanid=media_in.douban_id)
context = DoubanChain(db).recognize_by_doubanid(doubanid=media_in.douban_id)
if context:
mediainfo = context.media_info
meta = context.meta_info
else:
context = MediaChain().recognize_by_title(title=f"{media_in.title} {media_in.year}")
context = MediaChain(db).recognize_by_title(title=f"{media_in.title} {media_in.year}")
if context:
mediainfo = context.media_info
meta = context.meta_info
# 查询缺失信息
if not mediainfo or not mediainfo.tmdb_id:
raise HTTPException(status_code=404, detail="媒体信息不存在")
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
exist_flag, no_exists = DownloadChain(db).get_no_exists_info(meta=meta, mediainfo=mediainfo)
if mediainfo.type == MediaType.MOVIE:
# 电影已存在时返回空列表,存在时返回空对像列表
return [] if exist_flag else [NotExistMediaInfo()]
@ -90,31 +95,34 @@ def exists(media_in: schemas.MediaInfo,
@router.put("/{hashString}/start", summary="开始任务", response_model=schemas.Response)
def start_downloading(
hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
开如下载任务
"""
ret = DownloadChain().set_downloading(hashString, "start")
ret = DownloadChain(db).set_downloading(hashString, "start")
return schemas.Response(success=True if ret else False)
@router.put("/{hashString}/stop", summary="暂停任务", response_model=schemas.Response)
def stop_downloading(
hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
控制下载任务
"""
ret = DownloadChain().set_downloading(hashString, "stop")
ret = DownloadChain(db).set_downloading(hashString, "stop")
return schemas.Response(success=True if ret else False)
@router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response)
def remove_downloading(
hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
控制下载任务
"""
ret = DownloadChain().remove_downloading(hashString)
ret = DownloadChain(db).remove_downloading(hashString)
return schemas.Response(success=True if ret else False)

View File

@ -74,7 +74,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
if not history:
return schemas.Response(success=False, msg="记录不存在")
# 册除文件
TransferChain().delete_files(Path(history.dest))
TransferChain(db).delete_files(Path(history.dest))
# 删除记录
TransferHistory.delete(db, history_in.id)
return schemas.Response(success=True)
@ -84,12 +84,13 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
def redo_transfer_history(history_in: schemas.TransferHistory,
mtype: str,
new_tmdbid: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
历史记录重新转移
"""
hash_str = history_in.download_hash
result = TransferChain().process(f"{hash_str} {new_tmdbid}|{mtype}")
result = TransferChain(db).process(f"{hash_str} {new_tmdbid}|{mtype}")
if result:
return schemas.Response(success=True)
else:

View File

@ -36,7 +36,7 @@ async def login_access_token(
if not user:
# 请求协助认证
logger.warn("登录用户本地不匹配,尝试辅助认证 ...")
token = UserChain().user_authenticate(form_data.username, form_data.password)
token = UserChain(db).user_authenticate(form_data.username, form_data.password)
if not token:
raise HTTPException(status_code=401, detail="用户名或密码不正确")
else:
@ -83,11 +83,11 @@ def bing_wallpaper() -> Any:
@router.get("/tmdb", summary="TMDB电影海报", response_model=schemas.Response)
def tmdb_wallpaper() -> Any:
def tmdb_wallpaper(db: Session = Depends(get_db)) -> Any:
"""
获取TMDB电影海报
"""
infos = TmdbChain().tmdb_trending()
infos = TmdbChain(db).tmdb_trending()
if infos:
# 随机一个电影
while True:

View File

@ -20,12 +20,13 @@ router = APIRouter()
@router.get("/recognize", summary="识别媒体信息", response_model=schemas.Context)
def recognize(title: str,
subtitle: str = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据标题副标题识别媒体信息
"""
# 识别媒体信息
context = MediaChain().recognize_by_title(title=title, subtitle=subtitle)
context = MediaChain(db).recognize_by_title(title=title, subtitle=subtitle)
if context:
return context.to_dict()
return schemas.Context()
@ -35,11 +36,12 @@ def recognize(title: str,
def search_by_title(title: str,
page: int = 1,
count: int = 8,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
模糊搜索媒体信息列表
"""
_, medias = MediaChain().search(title=title)
_, medias = MediaChain(db).search(title=title)
if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return []
@ -69,20 +71,21 @@ def exists(title: str = None,
@router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo)
def tmdb_info(mediaid: str, type_name: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据媒体ID查询themoviedb或豆瓣媒体信息type_name: 电影/电视剧
"""
mtype = MediaType(type_name)
if mediaid.startswith("tmdb:"):
result = TmdbChain().tmdb_info(int(mediaid[5:]), mtype)
result = TmdbChain(db).tmdb_info(int(mediaid[5:]), mtype)
return MediaInfo(tmdb_info=result).to_dict()
elif mediaid.startswith("douban:"):
# 查询豆瓣信息
doubaninfo = DoubanChain().douban_info(doubanid=mediaid[7:])
doubaninfo = DoubanChain(db).douban_info(doubanid=mediaid[7:])
if not doubaninfo:
return schemas.MediaInfo()
result = DoubanChain().recognize_by_doubaninfo(doubaninfo)
result = DoubanChain(db).recognize_by_doubaninfo(doubaninfo)
if result:
# TMDB
return result.media_info.to_dict()

View File

@ -19,22 +19,23 @@ from app.schemas.types import SystemConfigKey, NotificationType
router = APIRouter()
def start_message_chain(body: Any, form: Any, args: Any):
def start_message_chain(db: Session, body: Any, form: Any, args: Any):
"""
启动链式任务
"""
MessageChain().process(body=body, form=form, args=args)
MessageChain(db).process(body=body, form=form, args=args)
@router.post("/", summary="接收用户消息", response_model=schemas.Response)
async def user_message(background_tasks: BackgroundTasks, request: Request):
async def user_message(background_tasks: BackgroundTasks, request: Request,
db: Session = Depends(get_db)):
"""
用户消息响应
"""
body = await request.body()
form = await request.form()
args = request.query_params
background_tasks.add_task(start_message_chain, body, form, args)
background_tasks.add_task(start_message_chain, db, body, form, args)
return schemas.Response(success=True)

View File

@ -15,11 +15,11 @@ from app.schemas import MediaType
router = APIRouter()
def start_rss_refresh(rssid: int = None):
def start_rss_refresh(db: Session, rssid: int = None):
"""
启动自定义订阅刷新
"""
RssChain().refresh(rssid=rssid, manual=True)
RssChain(db).refresh(rssid=rssid, manual=True)
@router.get("/", summary="所有自定义订阅", response_model=List[schemas.Rss])
@ -36,6 +36,7 @@ def read_rsses(
def create_rss(
*,
rss_in: schemas.Rss,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
@ -45,7 +46,7 @@ def create_rss(
mtype = MediaType(rss_in.type)
else:
mtype = None
rssid, errormsg = RssChain().add(
rssid, errormsg = RssChain(db).add(
mtype=mtype,
**rss_in.dict()
)
@ -100,11 +101,13 @@ def preview_rss(
def refresh_rss(
rssid: int,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据ID刷新自定义订阅
"""
background_tasks.add_task(start_rss_refresh,
db=db,
rssid=rssid)
return schemas.Response(success=True)

View File

@ -1,28 +1,32 @@
from typing import List, Any
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app import schemas
from app.chain.douban import DoubanChain
from app.chain.search import SearchChain
from app.core.security import verify_token
from app.db import get_db
from app.schemas.types import MediaType
router = APIRouter()
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search_latest(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询搜索结果
"""
torrents = SearchChain().last_search_results()
torrents = SearchChain(db).last_search_results()
return [torrent.to_dict() for torrent in torrents]
@router.get("/media/{mediaid}", summary="精确搜索资源", response_model=List[schemas.Context])
def search_by_tmdbid(mediaid: str,
mtype: str = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/
@ -31,14 +35,14 @@ def search_by_tmdbid(mediaid: str,
tmdbid = int(mediaid.replace("tmdb:", ""))
if mtype:
mtype = MediaType(mtype)
torrents = SearchChain().search_by_tmdbid(tmdbid=tmdbid, mtype=mtype)
torrents = SearchChain(db).search_by_tmdbid(tmdbid=tmdbid, mtype=mtype)
elif mediaid.startswith("douban:"):
doubanid = mediaid.replace("douban:", "")
# 识别豆瓣信息
context = DoubanChain().recognize_by_doubanid(doubanid)
context = DoubanChain(db).recognize_by_doubanid(doubanid)
if not context or not context.media_info or not context.media_info.tmdb_id:
raise HTTPException(status_code=404, detail="无法识别TMDB媒体信息")
torrents = SearchChain().search_by_tmdbid(tmdbid=context.media_info.tmdb_id,
torrents = SearchChain(db).search_by_tmdbid(tmdbid=context.media_info.tmdb_id,
mtype=context.media_info.type)
else:
return []
@ -49,9 +53,10 @@ def search_by_tmdbid(mediaid: str,
async def search_by_title(keyword: str = None,
page: int = 0,
site: int = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据名称模糊搜索站点资源支持分页关键词为空是返回首页资源
"""
torrents = SearchChain().search_by_title(title=keyword, page=page, site=site)
torrents = SearchChain(db).search_by_title(title=keyword, page=page, site=site)
return [torrent.to_dict() for torrent in torrents]

View File

@ -19,11 +19,11 @@ from app.utils.string import StringUtils
router = APIRouter()
def start_cookiecloud_sync():
def start_cookiecloud_sync(db: Session):
"""
后台启动CookieCloud站点同步
"""
CookieCloudChain().process(manual=True)
CookieCloudChain(db).process(manual=True)
@router.get("/", summary="所有站点", response_model=List[schemas.Site])
@ -67,11 +67,12 @@ def delete_site(
@router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response)
def cookie_cloud_sync(background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
运行CookieCloud同步站点信息
"""
background_tasks.add_task(start_cookiecloud_sync)
background_tasks.add_task(start_cookiecloud_sync, db)
return schemas.Response(success=True, message="CookieCloud同步任务已启动")
@ -83,7 +84,7 @@ def cookie_cloud_sync(db: Session = Depends(get_db),
"""
Site.reset(db)
SystemConfigOper(db).set(SystemConfigKey.IndexerSites, [])
CookieCloudChain().process(manual=True)
CookieCloudChain(db).process(manual=True)
return schemas.Response(success=True, message="站点已重置!")
@ -105,7 +106,7 @@ def update_cookie(
detail=f"站点 {site_id} 不存在!",
)
# 更新Cookie
state, message = SiteChain().update_cookie(site_info=site_info,
state, message = SiteChain(db).update_cookie(site_info=site_info,
username=username,
password=password)
return schemas.Response(success=state, message=message)
@ -124,7 +125,7 @@ def test_site(site_id: int,
status_code=404,
detail=f"站点 {site_id} 不存在",
)
status, message = SiteChain().test(site.domain)
status, message = SiteChain(db).test(site.domain)
return schemas.Response(success=status, message=message)
@ -162,7 +163,7 @@ def site_resource(site_id: int, keyword: str = None,
status_code=404,
detail=f"站点 {site_id} 不存在",
)
torrents = SearchChain().browse(site.domain, keyword)
torrents = SearchChain(db).browse(site.domain, keyword)
if not torrents:
return []
return [torrent.to_dict() for torrent in torrents]

View File

@ -17,20 +17,20 @@ from app.schemas.types import MediaType
router = APIRouter()
def start_subscribe_add(title: str, year: str,
def start_subscribe_add(db: Session, title: str, year: str,
mtype: MediaType, tmdbid: int, season: int, username: str):
"""
启动订阅任务
"""
SubscribeChain().add(title=title, year=year,
SubscribeChain(db).add(title=title, year=year,
mtype=mtype, tmdbid=tmdbid, season=season, username=username)
def start_subscribe_search(sid: Optional[int], state: Optional[str]):
def start_subscribe_search(db: Session, sid: Optional[int], state: Optional[str]):
"""
启动订阅搜索任务
"""
SubscribeChain().search(sid=sid, state=state, manual=True)
SubscribeChain(db).search(sid=sid, state=state, manual=True)
@router.get("/", summary="所有订阅", response_model=List[schemas.Subscribe])
@ -51,6 +51,7 @@ def read_subscribes(
def create_subscribe(
*,
subscribe_in: schemas.Subscribe,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user),
) -> Any:
"""
@ -66,7 +67,7 @@ def create_subscribe(
title = subscribe_in.name
else:
title = None
sid, message = SubscribeChain().add(mtype=mtype,
sid, message = SubscribeChain(db).add(mtype=mtype,
title=title,
year=subscribe_in.year,
tmdbid=subscribe_in.tmdbid,
@ -171,6 +172,7 @@ def delete_subscribe(
@router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response)
async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
authorization: str = Header(None)) -> Any:
"""
Jellyseerr/Overseerr订阅
@ -198,6 +200,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
# 添加订阅
if media_type == MediaType.MOVIE:
background_tasks.add_task(start_subscribe_add,
db=db,
mtype=media_type,
tmdbid=tmdbId,
title=subject,
@ -212,6 +215,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
break
for season in seasons:
background_tasks.add_task(start_subscribe_add,
db=db,
mtype=media_type,
tmdbid=tmdbId,
title=subject,
@ -224,11 +228,12 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
@router.get("/refresh", summary="刷新订阅", response_model=schemas.Response)
def refresh_subscribes(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
刷新所有订阅
"""
SubscribeChain().refresh()
SubscribeChain(db).refresh()
return schemas.Response(success=True)
@ -236,20 +241,22 @@ def refresh_subscribes(
def search_subscribe(
subscribe_id: int,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
搜索所有订阅
"""
background_tasks.add_task(start_subscribe_search, sid=subscribe_id, state=None)
background_tasks.add_task(start_subscribe_search, db=db, sid=subscribe_id, state=None)
return schemas.Response(success=True)
@router.get("/search", summary="搜索所有订阅", response_model=schemas.Response)
def search_subscribes(
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
搜索所有订阅
"""
background_tasks.add_task(start_subscribe_search, sid=None, state='R')
background_tasks.add_task(start_subscribe_search, db=db, sid=None, state='R')
return schemas.Response(success=True)

View File

@ -1,22 +1,25 @@
from typing import List, Any
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app import schemas
from app.chain.tmdb import TmdbChain
from app.core.context import MediaInfo
from app.core.security import verify_token
from app.db import get_db
from app.schemas.types import MediaType
router = APIRouter()
@router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason])
def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
def tmdb_seasons(tmdbid: int, db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询themoviedb所有季信息
"""
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid)
seasons_info = TmdbChain(db).tmdb_seasons(tmdbid=tmdbid)
if not seasons_info:
return []
else:
@ -26,15 +29,16 @@ def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -
@router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_similar(tmdbid: int,
type_name: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询类似电影/电视剧type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
tmdbinfos = TmdbChain().movie_similar(tmdbid=tmdbid)
tmdbinfos = TmdbChain(db).movie_similar(tmdbid=tmdbid)
elif mediatype == MediaType.TV:
tmdbinfos = TmdbChain().tv_similar(tmdbid=tmdbid)
tmdbinfos = TmdbChain(db).tv_similar(tmdbid=tmdbid)
else:
return []
if not tmdbinfos:
@ -46,15 +50,16 @@ def tmdb_similar(tmdbid: int,
@router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_recommend(tmdbid: int,
type_name: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询推荐电影/电视剧type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
tmdbinfos = TmdbChain().movie_recommend(tmdbid=tmdbid)
tmdbinfos = TmdbChain(db).movie_recommend(tmdbid=tmdbid)
elif mediatype == MediaType.TV:
tmdbinfos = TmdbChain().tv_recommend(tmdbid=tmdbid)
tmdbinfos = TmdbChain(db).tv_recommend(tmdbid=tmdbid)
else:
return []
if not tmdbinfos:
@ -67,15 +72,16 @@ def tmdb_recommend(tmdbid: int,
def tmdb_credits(tmdbid: int,
type_name: str,
page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询演员阵容type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
tmdbinfos = TmdbChain().movie_credits(tmdbid=tmdbid, page=page)
tmdbinfos = TmdbChain(db).movie_credits(tmdbid=tmdbid, page=page)
elif mediatype == MediaType.TV:
tmdbinfos = TmdbChain().tv_credits(tmdbid=tmdbid, page=page)
tmdbinfos = TmdbChain(db).tv_credits(tmdbid=tmdbid, page=page)
else:
return []
if not tmdbinfos:
@ -86,11 +92,12 @@ def tmdb_credits(tmdbid: int,
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.TmdbPerson)
def tmdb_person(person_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物详情
"""
tmdbinfo = TmdbChain().person_detail(person_id=person_id)
tmdbinfo = TmdbChain(db).person_detail(person_id=person_id)
if not tmdbinfo:
return schemas.TmdbPerson()
else:
@ -100,11 +107,12 @@ def tmdb_person(person_id: int,
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def tmdb_person_credits(person_id: int,
page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物参演作品
"""
tmdbinfo = TmdbChain().person_credits(person_id=person_id, page=page)
tmdbinfo = TmdbChain(db).person_credits(person_id=person_id, page=page)
if not tmdbinfo:
return []
else:
@ -116,11 +124,12 @@ def tmdb_movies(sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "",
page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB电影信息
"""
movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE,
movies = TmdbChain(db).tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
@ -135,11 +144,12 @@ def tmdb_tvs(sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "",
page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV,
tvs = TmdbChain(db).tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
@ -151,11 +161,12 @@ def tmdb_tvs(sort_by: str = "popularity.desc",
@router.get("/trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo])
def tmdb_trending(page: int = 1,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
infos = TmdbChain().tmdb_trending(page=page)
infos = TmdbChain(db).tmdb_trending(page=page)
if not infos:
return []
return [MediaInfo(tmdb_info=info).to_dict() for info in infos]
@ -163,11 +174,12 @@ def tmdb_trending(page: int = 1,
@router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode])
def tmdb_season_episodes(tmdbid: int, season: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询某季的所有信信息
"""
episodes_info = TmdbChain().tmdb_episodes(tmdbid=tmdbid, season=season)
episodes_info = TmdbChain(db).tmdb_episodes(tmdbid=tmdbid, season=season)
if not episodes_info:
return []
else:

View File

@ -1,24 +1,27 @@
from typing import Any
from fastapi import APIRouter, BackgroundTasks, Request
from fastapi import APIRouter, BackgroundTasks, Request, Depends
from sqlalchemy.orm import Session
from app import schemas
from app.chain.webhook import WebhookChain
from app.core.config import settings
from app.db import get_db
router = APIRouter()
def start_webhook_chain(body: Any, form: Any, args: Any):
def start_webhook_chain(db: Session, body: Any, form: Any, args: Any):
"""
启动链式任务
"""
WebhookChain().message(body=body, form=form, args=args)
WebhookChain(db).message(body=body, form=form, args=args)
@router.post("/", summary="Webhook消息响应", response_model=schemas.Response)
async def webhook_message(background_tasks: BackgroundTasks,
token: str, request: Request) -> Any:
token: str, request: Request,
db: Session = Depends(get_db),) -> Any:
"""
Webhook响应
"""
@ -27,18 +30,19 @@ async def webhook_message(background_tasks: BackgroundTasks,
body = await request.body()
form = await request.form()
args = request.query_params
background_tasks.add_task(start_webhook_chain, body, form, args)
background_tasks.add_task(start_webhook_chain, db, body, form, args)
return schemas.Response(success=True)
@router.get("/", summary="Webhook消息响应", response_model=schemas.Response)
async def webhook_message(background_tasks: BackgroundTasks,
token: str, request: Request) -> Any:
token: str, request: Request,
db: Session = Depends(get_db)) -> Any:
"""
Webhook响应
"""
if token != settings.API_TOKEN:
return schemas.Response(success=False, message="token认证不通过")
args = request.query_params
background_tasks.add_task(start_webhook_chain, None, None, args)
background_tasks.add_task(start_webhook_chain, db, None, None, args)
return schemas.Response(success=True)

View File

@ -235,11 +235,11 @@ def arr_movie_lookup(apikey: str, term: str, db: Session = Depends(get_db)) -> A
)
tmdbid = term.replace("tmdb:", "")
# 查询媒体信息
mediainfo = MediaChain().recognize_media(mtype=MediaType.MOVIE, tmdbid=int(tmdbid))
mediainfo = MediaChain(db).recognize_media(mtype=MediaType.MOVIE, tmdbid=int(tmdbid))
if not mediainfo:
return [RadarrMovie()]
# 查询是否已存在
exists = MediaChain().media_exists(mediainfo=mediainfo)
exists = MediaChain(db).media_exists(mediainfo=mediainfo)
if not exists:
# 文件不存在
hasfile = False
@ -324,7 +324,7 @@ def arr_add_movie(apikey: str,
"id": subscribe.id
}
# 添加订阅
sid, message = SubscribeChain().add(title=movie.title,
sid, message = SubscribeChain(db).add(title=movie.title,
year=movie.year,
mtype=MediaType.MOVIE,
tmdbid=movie.tmdbId,
@ -515,7 +515,7 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) ->
# 获取TVDBID
if not term.startswith("tvdb:"):
mediainfo = MediaChain().recognize_media(meta=MetaInfo(term),
mediainfo = MediaChain(db).recognize_media(meta=MetaInfo(term),
mtype=MediaType.TV)
if not mediainfo:
return [SonarrSeries()]
@ -527,7 +527,7 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) ->
tvdbid = int(term.replace("tvdb:", ""))
# 查询TVDB信息
tvdbinfo = MediaChain().tvdb_info(tvdbid=tvdbid)
tvdbinfo = MediaChain(db).tvdb_info(tvdbid=tvdbid)
if not tvdbinfo:
return [SonarrSeries()]
@ -539,11 +539,11 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) ->
# 根据TVDB查询媒体信息
if not mediainfo:
mediainfo = MediaChain().recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')),
mediainfo = MediaChain(db).recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')),
mtype=MediaType.TV)
# 查询是否存在
exists = MediaChain().media_exists(mediainfo)
exists = MediaChain(db).media_exists(mediainfo)
if exists:
hasfile = True
else:
@ -666,7 +666,7 @@ def arr_add_series(apikey: str, tv: schemas.SonarrSeries,
for season in left_seasons:
if not season.get("monitored"):
continue
sid, message = SubscribeChain().add(title=tv.title,
sid, message = SubscribeChain(db).add(title=tv.title,
year=tv.year,
season=season.get("seasonNumber"),
tmdbid=tv.tmdbId,

View File

@ -7,6 +7,7 @@ from typing import Optional, Any, Tuple, List, Set, Union, Dict
from qbittorrentapi import TorrentFilesList
from ruamel.yaml import CommentedMap
from sqlalchemy.orm import Session
from transmission_rpc import File
from app.core.config import settings
@ -27,10 +28,11 @@ class ChainBase(metaclass=ABCMeta):
处理链基类
"""
def __init__(self):
def __init__(self, db: Session = None):
"""
公共初始化
"""
self._db = db
self.modulemanager = ModuleManager()
self.eventmanager = EventManager()

View File

@ -3,6 +3,7 @@ from typing import Tuple, Optional, Union
from urllib.parse import urljoin
from lxml import etree
from sqlalchemy.orm import Session
from app.chain import ChainBase
from app.chain.site import SiteChain
@ -22,12 +23,12 @@ class CookieCloudChain(ChainBase):
CookieCloud处理链
"""
def __init__(self):
super().__init__()
self.siteoper = SiteOper()
self.siteiconoper = SiteIconOper()
def __init__(self, db: Session = None):
super().__init__(db)
self.siteoper = SiteOper(self._db)
self.siteiconoper = SiteIconOper(self._db)
self.siteshelper = SitesHelper()
self.sitechain = SiteChain()
self.sitechain = SiteChain(self._db)
self.message = MessageHelper()
self.cookiecloud = CookieCloudHelper(
server=settings.COOKIECLOUD_HOST,

View File

@ -2,6 +2,8 @@ import re
from pathlib import Path
from typing import List, Optional, Tuple, Set, Dict, Union
from sqlalchemy.orm import Session
from app.chain import ChainBase
from app.core.config import settings
from app.core.context import MediaInfo, TorrentInfo, Context
@ -20,11 +22,11 @@ class DownloadChain(ChainBase):
下载处理链
"""
def __init__(self):
super().__init__()
def __init__(self, db: Session = None):
super().__init__(db)
self.torrent = TorrentHelper()
self.downloadhis = DownloadHistoryOper()
self.mediaserver = MediaServerOper()
self.downloadhis = DownloadHistoryOper(self._db)
self.mediaserver = MediaServerOper(self._db)
def post_download_message(self, meta: MetaBase, mediainfo: MediaInfo, torrent: TorrentInfo,
channel: MessageChannel = None,

View File

@ -2,6 +2,8 @@ import json
import threading
from typing import List, Union, Generator
from sqlalchemy.orm import Session
from app import schemas
from app.chain import ChainBase
from app.core.config import settings
@ -17,9 +19,9 @@ class MediaServerChain(ChainBase):
媒体服务器处理链
"""
def __init__(self):
super().__init__()
self.mediaserverdb = MediaServerOper()
def __init__(self, db: Session = None):
super().__init__(db)
self.mediaserverdb = MediaServerOper(db)
def librarys(self) -> List[schemas.MediaServerLibrary]:
"""

View File

@ -27,12 +27,12 @@ class MessageChain(ChainBase):
# 每页数据量
_page_size: int = 8
def __init__(self):
super().__init__()
self.downloadchain = DownloadChain()
self.subscribechain = SubscribeChain()
self.searchchain = SearchChain()
self.medtachain = MediaChain()
def __init__(self, db: Session = None):
super().__init__(db)
self.downloadchain = DownloadChain(self._db)
self.subscribechain = SubscribeChain(self._db)
self.searchchain = SearchChain(self._db)
self.medtachain = MediaChain(self._db)
self.torrent = TorrentHelper()
self.eventmanager = EventManager()
self.torrenthelper = TorrentHelper()

View File

@ -4,6 +4,8 @@ import time
from datetime import datetime
from typing import Tuple, Optional
from sqlalchemy.orm import Session
from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.core.config import settings
@ -25,12 +27,12 @@ class RssChain(ChainBase):
RSS处理链
"""
def __init__(self):
super().__init__()
self.rssoper = RssOper()
def __init__(self, db: Session = None):
super().__init__(db)
self.rssoper = RssOper(self._db)
self.sites = SitesHelper()
self.systemconfig = SystemConfigOper()
self.downloadchain = DownloadChain()
self.systemconfig = SystemConfigOper(self._db)
self.downloadchain = DownloadChain(self._db)
self.message = MessageHelper()
def add(self, title: str, year: str,

View File

@ -4,6 +4,8 @@ from datetime import datetime
from typing import Dict
from typing import List, Optional
from sqlalchemy.orm import Session
from app.chain import ChainBase
from app.core.context import Context
from app.core.context import MediaInfo, TorrentInfo
@ -23,11 +25,11 @@ class SearchChain(ChainBase):
站点资源搜索处理链
"""
def __init__(self):
super().__init__()
def __init__(self, db: Session = None):
super().__init__(db)
self.siteshelper = SitesHelper()
self.progress = ProgressHelper()
self.systemconfig = SystemConfigOper()
self.systemconfig = SystemConfigOper(self._db)
self.torrenthelper = TorrentHelper()
def search_by_tmdbid(self, tmdbid: int, mtype: MediaType = None) -> List[Context]:

View File

@ -1,5 +1,7 @@
from typing import Union, Tuple
from sqlalchemy.orm import Session
from app.chain import ChainBase
from app.core.config import settings
from app.db.models.site import Site
@ -20,9 +22,9 @@ class SiteChain(ChainBase):
站点管理处理链
"""
def __init__(self):
super().__init__()
self.siteoper = SiteOper()
def __init__(self, db: Session = None):
super().__init__(db)
self.siteoper = SiteOper(self._db)
self.cookiehelper = CookieHelper()
self.message = MessageHelper()

View File

@ -3,6 +3,8 @@ import re
from datetime import datetime
from typing import Dict, List, Optional, Union, Tuple
from requests import Session
from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.search import SearchChain
@ -27,14 +29,14 @@ class SubscribeChain(ChainBase):
_cache_file = "__torrents_cache__"
def __init__(self):
super().__init__()
self.downloadchain = DownloadChain()
self.searchchain = SearchChain()
self.subscribehelper = SubscribeOper()
def __init__(self, db: Session = None):
super().__init__(db)
self.downloadchain = DownloadChain(self._db)
self.searchchain = SearchChain(self._db)
self.subscribehelper = SubscribeOper(self._db)
self.siteshelper = SitesHelper()
self.message = MessageHelper()
self.systemconfig = SystemConfigOper()
self.systemconfig = SystemConfigOper(self._db)
def add(self, title: str, year: str,
mtype: MediaType = None,

View File

@ -5,6 +5,8 @@ import threading
from pathlib import Path
from typing import List, Optional, Union
from sqlalchemy.orm import Session
from app.chain import ChainBase
from app.core.config import settings
from app.core.context import MediaInfo
@ -28,10 +30,10 @@ class TransferChain(ChainBase):
文件转移处理链
"""
def __init__(self):
super().__init__()
self.downloadhis = DownloadHistoryOper()
self.transferhis = TransferHistoryOper()
def __init__(self, db: Session = None):
super().__init__(db)
self.downloadhis = DownloadHistoryOper(self._db)
self.transferhis = TransferHistoryOper(self._db)
self.progress = ProgressHelper()
def process(self, arg_str: str = None, channel: MessageChannel = None, userid: Union[str, int] = None) -> bool:

View File

@ -33,8 +33,11 @@ class DbOper:
_db: Session = None
def __init__(self, _db=SessionLocal()):
self._db = _db
def __init__(self, db: Session = None):
if db:
self._db = db
else:
self._db = SessionLocal()
def __del__(self):
if self._db:

View File

@ -15,7 +15,7 @@ class DownloadHistoryOper(DbOper):
按路径查询下载记录
:param path: 数据key
"""
return DownloadHistory.get_by_path(self._db, path)
return DownloadHistory.get_by_path(self._db, str(path))
def get_by_hash(self, download_hash: str) -> Any:
"""

View File

@ -1,7 +1,9 @@
import json
from typing import Optional
from app.db import DbOper, SessionLocal
from sqlalchemy.orm import Session
from app.db import DbOper
from app.db.models.mediaserver import MediaServerItem
@ -10,7 +12,7 @@ class MediaServerOper(DbOper):
媒体服务器数据管理
"""
def __init__(self, db=SessionLocal()):
def __init__(self, db: Session = None):
super().__init__(db)
def add(self, **kwargs) -> bool:

View File

@ -1,6 +1,8 @@
from typing import List
from app.db import DbOper, SessionLocal
from sqlalchemy.orm import Session
from app.db import DbOper
from app.db.models.rss import Rss
@ -9,7 +11,7 @@ class RssOper(DbOper):
RSS订阅数据管理
"""
def __init__(self, db=SessionLocal()):
def __init__(self, db: Session = None):
super().__init__(db)
def add(self, **kwargs) -> bool:

View File

@ -1,22 +1,24 @@
import json
from typing import Any, Union
from app.db import DbOper, SessionLocal
from sqlalchemy.orm import Session
from app.db import DbOper
from app.db.models.systemconfig import SystemConfig
from app.schemas.types import SystemConfigKey
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton
from app.schemas.types import SystemConfigKey
class SystemConfigOper(DbOper, metaclass=Singleton):
# 配置对象
__SYSTEMCONF: dict = {}
def __init__(self, _db=SessionLocal()):
def __init__(self, db: Session = None):
"""
加载配置到内存
"""
super().__init__(_db)
super().__init__(db)
for item in SystemConfig.list(self._db):
if ObjectUtils.is_obj(item.value):
self.__SYSTEMCONF[item.key] = json.loads(item.value)