fix db session

This commit is contained in:
jxxghp
2023-10-18 08:35:16 +08:00
parent 84f5ce8a0b
commit fb78a07662
49 changed files with 170 additions and 224 deletions

View File

@ -1,7 +1,6 @@
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app import schemas
from app.chain.douban import DoubanChain
@ -10,7 +9,6 @@ from app.chain.media import MediaChain
from app.core.context import MediaInfo, Context, TorrentInfo
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.db import get_db
from app.schemas import NotExistMediaInfo, MediaType
router = APIRouter()
@ -18,19 +16,17 @@ router = APIRouter()
@router.get("/", summary="正在下载", response_model=List[schemas.DownloadingTorrent])
def read_downloading(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询正在下载的任务
"""
return DownloadChain(db).downloading()
return DownloadChain().downloading()
@router.post("/", summary="添加下载", response_model=schemas.Response)
def add_downloading(
media_in: schemas.MediaInfo,
torrent_in: schemas.TorrentInfo,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
添加下载任务
@ -49,7 +45,7 @@ def add_downloading(
media_info=mediainfo,
torrent_info=torrentinfo
)
did = DownloadChain(db).download_single(context=context)
did = DownloadChain().download_single(context=context)
return schemas.Response(success=True if did else False, data={
"download_id": did
})
@ -57,7 +53,6 @@ def add_downloading(
@router.post("/notexists", summary="查询缺失媒体信息", response_model=List[NotExistMediaInfo])
def exists(media_in: schemas.MediaInfo,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询缺失媒体信息
@ -80,7 +75,7 @@ def exists(media_in: schemas.MediaInfo,
# 查询缺失信息
if not mediainfo or not mediainfo.tmdb_id:
raise HTTPException(status_code=404, detail="媒体信息不存在")
exist_flag, no_exists = DownloadChain(db).get_no_exists_info(meta=meta, mediainfo=mediainfo)
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
if mediainfo.type == MediaType.MOVIE:
# 电影已存在时返回空列表,存在时返回空对像列表
return [] if exist_flag else [NotExistMediaInfo()]
@ -93,34 +88,31 @@ def exists(media_in: schemas.MediaInfo,
@router.get("/start/{hashString}", summary="开始任务", response_model=schemas.Response)
def start_downloading(
hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
开如下载任务
"""
ret = DownloadChain(db).set_downloading(hashString, "start")
ret = DownloadChain().set_downloading(hashString, "start")
return schemas.Response(success=True if ret else False)
@router.get("/stop/{hashString}", summary="暂停任务", response_model=schemas.Response)
def stop_downloading(
hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
控制下载任务
"""
ret = DownloadChain(db).set_downloading(hashString, "stop")
ret = DownloadChain().set_downloading(hashString, "stop")
return schemas.Response(success=True if ret else False)
@router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response)
def remove_downloading(
hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
控制下载任务
"""
ret = DownloadChain(db).remove_downloading(hashString)
ret = DownloadChain().remove_downloading(hashString)
return schemas.Response(success=True if ret else False)

View File

@ -75,10 +75,10 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
return schemas.Response(success=False, msg="记录不存在")
# 册除媒体库文件
if deletedest and history.dest:
TransferChain(db).delete_files(Path(history.dest))
TransferChain().delete_files(Path(history.dest))
# 删除源文件
if deletesrc and history.src:
TransferChain(db).delete_files(Path(history.src))
TransferChain().delete_files(Path(history.src))
# 发送事件
eventmanager.send_event(
EventType.DownloadFileDeleted,

View File

@ -35,7 +35,7 @@ async def login_access_token(
if not user:
# 请求协助认证
logger.warn("登录用户本地不匹配,尝试辅助认证 ...")
token = UserChain(db).user_authenticate(form_data.username, form_data.password)
token = UserChain().user_authenticate(form_data.username, form_data.password)
if not token:
raise HTTPException(status_code=401, detail="用户名或密码不正确")
else:

View File

@ -2,14 +2,12 @@ from typing import Union, Any, List
from fastapi import APIRouter, BackgroundTasks, Depends
from fastapi import Request
from sqlalchemy.orm import Session
from starlette.responses import PlainTextResponse
from app import schemas
from app.chain.message import MessageChain
from app.core.config import settings
from app.core.security import verify_token
from app.db import get_db
from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
@ -19,23 +17,22 @@ from app.schemas.types import SystemConfigKey, NotificationType
router = APIRouter()
def start_message_chain(db: Session, body: Any, form: Any, args: Any):
def start_message_chain(body: Any, form: Any, args: Any):
"""
启动链式任务
"""
MessageChain(db).process(body=body, form=form, args=args)
MessageChain().process(body=body, form=form, args=args)
@router.post("/", summary="接收用户消息", response_model=schemas.Response)
async def user_message(background_tasks: BackgroundTasks, request: Request,
db: Session = Depends(get_db)):
async def user_message(background_tasks: BackgroundTasks, request: Request):
"""
用户消息响应
"""
body = await request.body()
form = await request.form()
args = request.query_params
background_tasks.add_task(start_message_chain, db, body, form, args)
background_tasks.add_task(start_message_chain, body, form, args)
return schemas.Response(success=True)

View File

@ -1,25 +1,22 @@
from typing import List, Any
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app import schemas
from app.chain.douban import DoubanChain
from app.chain.search import SearchChain
from app.core.security import verify_token
from app.db import get_db
from app.schemas.types import MediaType
router = APIRouter()
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
async def search_latest(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询搜索结果
"""
torrents = SearchChain(db).last_search_results()
torrents = SearchChain().last_search_results()
return [torrent.to_dict() for torrent in torrents]
@ -27,7 +24,6 @@ async def search_latest(db: Session = Depends(get_db),
def search_by_tmdbid(mediaid: str,
mtype: str = None,
area: str = "title",
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/
@ -36,16 +32,16 @@ def search_by_tmdbid(mediaid: str,
tmdbid = int(mediaid.replace("tmdb:", ""))
if mtype:
mtype = MediaType(mtype)
torrents = SearchChain(db).search_by_tmdbid(tmdbid=tmdbid, mtype=mtype, area=area)
torrents = SearchChain().search_by_tmdbid(tmdbid=tmdbid, mtype=mtype, area=area)
elif mediaid.startswith("douban:"):
doubanid = mediaid.replace("douban:", "")
# 识别豆瓣信息
context = DoubanChain().recognize_by_doubanid(doubanid)
if not context or not context.media_info or not context.media_info.tmdb_id:
return []
torrents = SearchChain(db).search_by_tmdbid(tmdbid=context.media_info.tmdb_id,
mtype=context.media_info.type,
area=area)
torrents = SearchChain().search_by_tmdbid(tmdbid=context.media_info.tmdb_id,
mtype=context.media_info.type,
area=area)
else:
return []
return [torrent.to_dict() for torrent in torrents]
@ -55,10 +51,9 @@ def search_by_tmdbid(mediaid: str,
async def search_by_title(keyword: str = None,
page: int = 0,
site: int = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
"""
torrents = SearchChain(db).search_by_title(title=keyword, page=page, site=site)
torrents = SearchChain().search_by_title(title=keyword, page=page, site=site)
return [torrent.to_dict() for torrent in torrents]

View File

@ -139,9 +139,9 @@ def update_cookie(
detail=f"站点 {site_id} 不存在!",
)
# 更新Cookie
state, message = SiteChain(db).update_cookie(site_info=site_info,
username=username,
password=password)
state, message = SiteChain().update_cookie(site_info=site_info,
username=username,
password=password)
return schemas.Response(success=state, message=message)
@ -158,7 +158,7 @@ def test_site(site_id: int,
status_code=404,
detail=f"站点 {site_id} 不存在",
)
status, message = SiteChain(db).test(site.domain)
status, message = SiteChain().test(site.domain)
return schemas.Response(success=status, message=message)

View File

@ -18,13 +18,13 @@ from app.schemas.types import MediaType
router = APIRouter()
def start_subscribe_add(db: Session, title: str, year: str,
def start_subscribe_add(title: str, year: str,
mtype: MediaType, tmdbid: int, season: int, username: str):
"""
启动订阅任务
"""
SubscribeChain(db).add(title=title, year=year,
mtype=mtype, tmdbid=tmdbid, season=season, username=username)
SubscribeChain().add(title=title, year=year,
mtype=mtype, tmdbid=tmdbid, season=season, username=username)
@router.get("/", summary="所有订阅", response_model=List[schemas.Subscribe])
@ -45,7 +45,6 @@ def read_subscribes(
def create_subscribe(
*,
subscribe_in: schemas.Subscribe,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user),
) -> Any:
"""
@ -61,15 +60,15 @@ def create_subscribe(
title = subscribe_in.name
else:
title = None
sid, message = SubscribeChain(db).add(mtype=mtype,
title=title,
year=subscribe_in.year,
tmdbid=subscribe_in.tmdbid,
season=subscribe_in.season,
doubanid=subscribe_in.doubanid,
username=current_user.name,
best_version=subscribe_in.best_version,
exist_ok=True)
sid, message = SubscribeChain().add(mtype=mtype,
title=title,
year=subscribe_in.year,
tmdbid=subscribe_in.tmdbid,
season=subscribe_in.season,
doubanid=subscribe_in.doubanid,
username=current_user.name,
best_version=subscribe_in.best_version,
exist_ok=True)
return schemas.Response(success=True if sid else False, message=message, data={
"id": sid
})
@ -240,7 +239,6 @@ def delete_subscribe(
@router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response)
async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
authorization: str = Header(None)) -> Any:
"""
Jellyseerr/Overseerr订阅
@ -268,7 +266,6 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
# 添加订阅
if media_type == MediaType.MOVIE:
background_tasks.add_task(start_subscribe_add,
db=db,
mtype=media_type,
tmdbid=tmdbId,
title=subject,
@ -283,7 +280,6 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
break
for season in seasons:
background_tasks.add_task(start_subscribe_add,
db=db,
mtype=media_type,
tmdbid=tmdbId,
title=subject,

View File

@ -6,13 +6,11 @@ from typing import Union
import tailer
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app import schemas
from app.chain.search import SearchChain
from app.core.config import settings
from app.core.security import verify_token
from app.db import get_db
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper
@ -174,7 +172,6 @@ def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
def ruletest(title: str,
subtitle: str = None,
ruletype: str = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)):
"""
过滤规则测试,规则类型 1-订阅2-洗版3-搜索
@ -193,8 +190,8 @@ def ruletest(title: str,
return schemas.Response(success=False, message="优先级规则未设置!")
# 过滤
result = SearchChain(db).filter_torrents(rule_string=rule_string,
torrent_list=[torrent])
result = SearchChain().filter_torrents(rule_string=rule_string,
torrent_list=[torrent])
if not result:
return schemas.Response(success=False, message="不符合优先级规则!")
return schemas.Response(success=True, data={

View File

@ -59,7 +59,7 @@ def manual_transfer(path: str = None,
# 目的路径
if history.dest and str(history.dest) != "None":
# 删除旧的已整理文件
TransferChain(db).delete_files(Path(history.dest))
TransferChain().delete_files(Path(history.dest))
if not target:
target = history.dest
elif path:
@ -84,7 +84,7 @@ def manual_transfer(path: str = None,
offset=episode_offset,
)
# 开始转移
state, errormsg = TransferChain(db).manual_transfer(
state, errormsg = TransferChain().manual_transfer(
in_path=in_path,
target=target,
tmdbid=tmdbid,