fix db session

This commit is contained in:
jxxghp 2023-10-18 08:35:16 +08:00
parent 84f5ce8a0b
commit fb78a07662
49 changed files with 170 additions and 224 deletions

View File

@ -1,7 +1,6 @@
from typing import Any, List from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.douban import DoubanChain from app.chain.douban import DoubanChain
@ -10,7 +9,6 @@ from app.chain.media import MediaChain
from app.core.context import MediaInfo, Context, TorrentInfo from app.core.context import MediaInfo, Context, TorrentInfo
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db
from app.schemas import NotExistMediaInfo, MediaType from app.schemas import NotExistMediaInfo, MediaType
router = APIRouter() router = APIRouter()
@ -18,19 +16,17 @@ router = APIRouter()
@router.get("/", summary="正在下载", response_model=List[schemas.DownloadingTorrent]) @router.get("/", summary="正在下载", response_model=List[schemas.DownloadingTorrent])
def read_downloading( def read_downloading(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询正在下载的任务 查询正在下载的任务
""" """
return DownloadChain(db).downloading() return DownloadChain().downloading()
@router.post("/", summary="添加下载", response_model=schemas.Response) @router.post("/", summary="添加下载", response_model=schemas.Response)
def add_downloading( def add_downloading(
media_in: schemas.MediaInfo, media_in: schemas.MediaInfo,
torrent_in: schemas.TorrentInfo, torrent_in: schemas.TorrentInfo,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
添加下载任务 添加下载任务
@ -49,7 +45,7 @@ def add_downloading(
media_info=mediainfo, media_info=mediainfo,
torrent_info=torrentinfo torrent_info=torrentinfo
) )
did = DownloadChain(db).download_single(context=context) did = DownloadChain().download_single(context=context)
return schemas.Response(success=True if did else False, data={ return schemas.Response(success=True if did else False, data={
"download_id": did "download_id": did
}) })
@ -57,7 +53,6 @@ def add_downloading(
@router.post("/notexists", summary="查询缺失媒体信息", response_model=List[NotExistMediaInfo]) @router.post("/notexists", summary="查询缺失媒体信息", response_model=List[NotExistMediaInfo])
def exists(media_in: schemas.MediaInfo, def exists(media_in: schemas.MediaInfo,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询缺失媒体信息 查询缺失媒体信息
@ -80,7 +75,7 @@ def exists(media_in: schemas.MediaInfo,
# 查询缺失信息 # 查询缺失信息
if not mediainfo or not mediainfo.tmdb_id: if not mediainfo or not mediainfo.tmdb_id:
raise HTTPException(status_code=404, detail="媒体信息不存在") raise HTTPException(status_code=404, detail="媒体信息不存在")
exist_flag, no_exists = DownloadChain(db).get_no_exists_info(meta=meta, mediainfo=mediainfo) exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
if mediainfo.type == MediaType.MOVIE: if mediainfo.type == MediaType.MOVIE:
# 电影已存在时返回空列表,存在时返回空对像列表 # 电影已存在时返回空列表,存在时返回空对像列表
return [] if exist_flag else [NotExistMediaInfo()] return [] if exist_flag else [NotExistMediaInfo()]
@ -93,34 +88,31 @@ def exists(media_in: schemas.MediaInfo,
@router.get("/start/{hashString}", summary="开始任务", response_model=schemas.Response) @router.get("/start/{hashString}", summary="开始任务", response_model=schemas.Response)
def start_downloading( def start_downloading(
hashString: str, hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
开如下载任务 开如下载任务
""" """
ret = DownloadChain(db).set_downloading(hashString, "start") ret = DownloadChain().set_downloading(hashString, "start")
return schemas.Response(success=True if ret else False) return schemas.Response(success=True if ret else False)
@router.get("/stop/{hashString}", summary="暂停任务", response_model=schemas.Response) @router.get("/stop/{hashString}", summary="暂停任务", response_model=schemas.Response)
def stop_downloading( def stop_downloading(
hashString: str, hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
控制下载任务 控制下载任务
""" """
ret = DownloadChain(db).set_downloading(hashString, "stop") ret = DownloadChain().set_downloading(hashString, "stop")
return schemas.Response(success=True if ret else False) return schemas.Response(success=True if ret else False)
@router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response) @router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response)
def remove_downloading( def remove_downloading(
hashString: str, hashString: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
控制下载任务 控制下载任务
""" """
ret = DownloadChain(db).remove_downloading(hashString) ret = DownloadChain().remove_downloading(hashString)
return schemas.Response(success=True if ret else False) return schemas.Response(success=True if ret else False)

View File

@ -75,10 +75,10 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
return schemas.Response(success=False, msg="记录不存在") return schemas.Response(success=False, msg="记录不存在")
# 册除媒体库文件 # 册除媒体库文件
if deletedest and history.dest: if deletedest and history.dest:
TransferChain(db).delete_files(Path(history.dest)) TransferChain().delete_files(Path(history.dest))
# 删除源文件 # 删除源文件
if deletesrc and history.src: if deletesrc and history.src:
TransferChain(db).delete_files(Path(history.src)) TransferChain().delete_files(Path(history.src))
# 发送事件 # 发送事件
eventmanager.send_event( eventmanager.send_event(
EventType.DownloadFileDeleted, EventType.DownloadFileDeleted,

View File

@ -35,7 +35,7 @@ async def login_access_token(
if not user: if not user:
# 请求协助认证 # 请求协助认证
logger.warn("登录用户本地不匹配,尝试辅助认证 ...") logger.warn("登录用户本地不匹配,尝试辅助认证 ...")
token = UserChain(db).user_authenticate(form_data.username, form_data.password) token = UserChain().user_authenticate(form_data.username, form_data.password)
if not token: if not token:
raise HTTPException(status_code=401, detail="用户名或密码不正确") raise HTTPException(status_code=401, detail="用户名或密码不正确")
else: else:

View File

@ -2,14 +2,12 @@ from typing import Union, Any, List
from fastapi import APIRouter, BackgroundTasks, Depends from fastapi import APIRouter, BackgroundTasks, Depends
from fastapi import Request from fastapi import Request
from sqlalchemy.orm import Session
from starlette.responses import PlainTextResponse from starlette.responses import PlainTextResponse
from app import schemas from app import schemas
from app.chain.message import MessageChain from app.chain.message import MessageChain
from app.core.config import settings from app.core.config import settings
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger from app.log import logger
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
@ -19,23 +17,22 @@ from app.schemas.types import SystemConfigKey, NotificationType
router = APIRouter() router = APIRouter()
def start_message_chain(db: Session, body: Any, form: Any, args: Any): def start_message_chain(body: Any, form: Any, args: Any):
""" """
启动链式任务 启动链式任务
""" """
MessageChain(db).process(body=body, form=form, args=args) MessageChain().process(body=body, form=form, args=args)
@router.post("/", summary="接收用户消息", response_model=schemas.Response) @router.post("/", summary="接收用户消息", response_model=schemas.Response)
async def user_message(background_tasks: BackgroundTasks, request: Request, async def user_message(background_tasks: BackgroundTasks, request: Request):
db: Session = Depends(get_db)):
""" """
用户消息响应 用户消息响应
""" """
body = await request.body() body = await request.body()
form = await request.form() form = await request.form()
args = request.query_params args = request.query_params
background_tasks.add_task(start_message_chain, db, body, form, args) background_tasks.add_task(start_message_chain, body, form, args)
return schemas.Response(success=True) return schemas.Response(success=True)

View File

@ -1,25 +1,22 @@
from typing import List, Any from typing import List, Any
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.douban import DoubanChain from app.chain.douban import DoubanChain
from app.chain.search import SearchChain from app.chain.search import SearchChain
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db
from app.schemas.types import MediaType from app.schemas.types import MediaType
router = APIRouter() router = APIRouter()
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context]) @router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
async def search_latest(db: Session = Depends(get_db), async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询搜索结果 查询搜索结果
""" """
torrents = SearchChain(db).last_search_results() torrents = SearchChain().last_search_results()
return [torrent.to_dict() for torrent in torrents] return [torrent.to_dict() for torrent in torrents]
@ -27,7 +24,6 @@ async def search_latest(db: Session = Depends(get_db),
def search_by_tmdbid(mediaid: str, def search_by_tmdbid(mediaid: str,
mtype: str = None, mtype: str = None,
area: str = "title", area: str = "title",
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/ 根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/
@ -36,16 +32,16 @@ def search_by_tmdbid(mediaid: str,
tmdbid = int(mediaid.replace("tmdb:", "")) tmdbid = int(mediaid.replace("tmdb:", ""))
if mtype: if mtype:
mtype = MediaType(mtype) mtype = MediaType(mtype)
torrents = SearchChain(db).search_by_tmdbid(tmdbid=tmdbid, mtype=mtype, area=area) torrents = SearchChain().search_by_tmdbid(tmdbid=tmdbid, mtype=mtype, area=area)
elif mediaid.startswith("douban:"): elif mediaid.startswith("douban:"):
doubanid = mediaid.replace("douban:", "") doubanid = mediaid.replace("douban:", "")
# 识别豆瓣信息 # 识别豆瓣信息
context = DoubanChain().recognize_by_doubanid(doubanid) context = DoubanChain().recognize_by_doubanid(doubanid)
if not context or not context.media_info or not context.media_info.tmdb_id: if not context or not context.media_info or not context.media_info.tmdb_id:
return [] return []
torrents = SearchChain(db).search_by_tmdbid(tmdbid=context.media_info.tmdb_id, torrents = SearchChain().search_by_tmdbid(tmdbid=context.media_info.tmdb_id,
mtype=context.media_info.type, mtype=context.media_info.type,
area=area) area=area)
else: else:
return [] return []
return [torrent.to_dict() for torrent in torrents] return [torrent.to_dict() for torrent in torrents]
@ -55,10 +51,9 @@ def search_by_tmdbid(mediaid: str,
async def search_by_title(keyword: str = None, async def search_by_title(keyword: str = None,
page: int = 0, page: int = 0,
site: int = None, site: int = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据名称模糊搜索站点资源支持分页关键词为空是返回首页资源 根据名称模糊搜索站点资源支持分页关键词为空是返回首页资源
""" """
torrents = SearchChain(db).search_by_title(title=keyword, page=page, site=site) torrents = SearchChain().search_by_title(title=keyword, page=page, site=site)
return [torrent.to_dict() for torrent in torrents] return [torrent.to_dict() for torrent in torrents]

View File

@ -139,9 +139,9 @@ def update_cookie(
detail=f"站点 {site_id} 不存在!", detail=f"站点 {site_id} 不存在!",
) )
# 更新Cookie # 更新Cookie
state, message = SiteChain(db).update_cookie(site_info=site_info, state, message = SiteChain().update_cookie(site_info=site_info,
username=username, username=username,
password=password) password=password)
return schemas.Response(success=state, message=message) return schemas.Response(success=state, message=message)
@ -158,7 +158,7 @@ def test_site(site_id: int,
status_code=404, status_code=404,
detail=f"站点 {site_id} 不存在", detail=f"站点 {site_id} 不存在",
) )
status, message = SiteChain(db).test(site.domain) status, message = SiteChain().test(site.domain)
return schemas.Response(success=status, message=message) return schemas.Response(success=status, message=message)

View File

@ -18,13 +18,13 @@ from app.schemas.types import MediaType
router = APIRouter() router = APIRouter()
def start_subscribe_add(db: Session, title: str, year: str, def start_subscribe_add(title: str, year: str,
mtype: MediaType, tmdbid: int, season: int, username: str): mtype: MediaType, tmdbid: int, season: int, username: str):
""" """
启动订阅任务 启动订阅任务
""" """
SubscribeChain(db).add(title=title, year=year, SubscribeChain().add(title=title, year=year,
mtype=mtype, tmdbid=tmdbid, season=season, username=username) mtype=mtype, tmdbid=tmdbid, season=season, username=username)
@router.get("/", summary="所有订阅", response_model=List[schemas.Subscribe]) @router.get("/", summary="所有订阅", response_model=List[schemas.Subscribe])
@ -45,7 +45,6 @@ def read_subscribes(
def create_subscribe( def create_subscribe(
*, *,
subscribe_in: schemas.Subscribe, subscribe_in: schemas.Subscribe,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user), current_user: User = Depends(get_current_active_user),
) -> Any: ) -> Any:
""" """
@ -61,15 +60,15 @@ def create_subscribe(
title = subscribe_in.name title = subscribe_in.name
else: else:
title = None title = None
sid, message = SubscribeChain(db).add(mtype=mtype, sid, message = SubscribeChain().add(mtype=mtype,
title=title, title=title,
year=subscribe_in.year, year=subscribe_in.year,
tmdbid=subscribe_in.tmdbid, tmdbid=subscribe_in.tmdbid,
season=subscribe_in.season, season=subscribe_in.season,
doubanid=subscribe_in.doubanid, doubanid=subscribe_in.doubanid,
username=current_user.name, username=current_user.name,
best_version=subscribe_in.best_version, best_version=subscribe_in.best_version,
exist_ok=True) exist_ok=True)
return schemas.Response(success=True if sid else False, message=message, data={ return schemas.Response(success=True if sid else False, message=message, data={
"id": sid "id": sid
}) })
@ -240,7 +239,6 @@ def delete_subscribe(
@router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response) @router.post("/seerr", summary="OverSeerr/JellySeerr通知订阅", response_model=schemas.Response)
async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
authorization: str = Header(None)) -> Any: authorization: str = Header(None)) -> Any:
""" """
Jellyseerr/Overseerr订阅 Jellyseerr/Overseerr订阅
@ -268,7 +266,6 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
# 添加订阅 # 添加订阅
if media_type == MediaType.MOVIE: if media_type == MediaType.MOVIE:
background_tasks.add_task(start_subscribe_add, background_tasks.add_task(start_subscribe_add,
db=db,
mtype=media_type, mtype=media_type,
tmdbid=tmdbId, tmdbid=tmdbId,
title=subject, title=subject,
@ -283,7 +280,6 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
break break
for season in seasons: for season in seasons:
background_tasks.add_task(start_subscribe_add, background_tasks.add_task(start_subscribe_add,
db=db,
mtype=media_type, mtype=media_type,
tmdbid=tmdbId, tmdbid=tmdbId,
title=subject, title=subject,

View File

@ -6,13 +6,11 @@ from typing import Union
import tailer import tailer
from fastapi import APIRouter, HTTPException, Depends from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.search import SearchChain from app.chain.search import SearchChain
from app.core.config import settings from app.core.config import settings
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.helper.message import MessageHelper from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper from app.helper.progress import ProgressHelper
@ -174,7 +172,6 @@ def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
def ruletest(title: str, def ruletest(title: str,
subtitle: str = None, subtitle: str = None,
ruletype: str = None, ruletype: str = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)): _: schemas.TokenPayload = Depends(verify_token)):
""" """
过滤规则测试规则类型 1-订阅2-洗版3-搜索 过滤规则测试规则类型 1-订阅2-洗版3-搜索
@ -193,8 +190,8 @@ def ruletest(title: str,
return schemas.Response(success=False, message="优先级规则未设置!") return schemas.Response(success=False, message="优先级规则未设置!")
# 过滤 # 过滤
result = SearchChain(db).filter_torrents(rule_string=rule_string, result = SearchChain().filter_torrents(rule_string=rule_string,
torrent_list=[torrent]) torrent_list=[torrent])
if not result: if not result:
return schemas.Response(success=False, message="不符合优先级规则!") return schemas.Response(success=False, message="不符合优先级规则!")
return schemas.Response(success=True, data={ return schemas.Response(success=True, data={

View File

@ -59,7 +59,7 @@ def manual_transfer(path: str = None,
# 目的路径 # 目的路径
if history.dest and str(history.dest) != "None": if history.dest and str(history.dest) != "None":
# 删除旧的已整理文件 # 删除旧的已整理文件
TransferChain(db).delete_files(Path(history.dest)) TransferChain().delete_files(Path(history.dest))
if not target: if not target:
target = history.dest target = history.dest
elif path: elif path:
@ -84,7 +84,7 @@ def manual_transfer(path: str = None,
offset=episode_offset, offset=episode_offset,
) )
# 开始转移 # 开始转移
state, errormsg = TransferChain(db).manual_transfer( state, errormsg = TransferChain().manual_transfer(
in_path=in_path, in_path=in_path,
target=target, target=target,
tmdbid=tmdbid, tmdbid=tmdbid,

View File

@ -390,11 +390,11 @@ def arr_add_movie(apikey: str,
"id": subscribe.id "id": subscribe.id
} }
# 添加订阅 # 添加订阅
sid, message = SubscribeChain(db).add(title=movie.title, sid, message = SubscribeChain().add(title=movie.title,
year=movie.year, year=movie.year,
mtype=MediaType.MOVIE, mtype=MediaType.MOVIE,
tmdbid=movie.tmdbId, tmdbid=movie.tmdbId,
userid="Seerr") userid="Seerr")
if sid: if sid:
return { return {
"id": sid "id": sid
@ -582,7 +582,7 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) ->
# 获取TVDBID # 获取TVDBID
if not term.startswith("tvdb:"): if not term.startswith("tvdb:"):
mediainfo = MediaChain().recognize_media(meta=MetaInfo(term), mediainfo = MediaChain().recognize_media(meta=MetaInfo(term),
mtype=MediaType.TV) mtype=MediaType.TV)
if not mediainfo: if not mediainfo:
return [SonarrSeries()] return [SonarrSeries()]
tvdbid = mediainfo.tvdb_id tvdbid = mediainfo.tvdb_id
@ -606,7 +606,7 @@ def arr_series_lookup(apikey: str, term: str, db: Session = Depends(get_db)) ->
# 根据TVDB查询媒体信息 # 根据TVDB查询媒体信息
if not mediainfo: if not mediainfo:
mediainfo = MediaChain().recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')), mediainfo = MediaChain().recognize_media(meta=MetaInfo(tvdbinfo.get('seriesName')),
mtype=MediaType.TV) mtype=MediaType.TV)
# 查询是否存在 # 查询是否存在
exists = MediaChain().media_exists(mediainfo) exists = MediaChain().media_exists(mediainfo)
@ -732,12 +732,12 @@ def arr_add_series(apikey: str, tv: schemas.SonarrSeries,
for season in left_seasons: for season in left_seasons:
if not season.get("monitored"): if not season.get("monitored"):
continue continue
sid, message = SubscribeChain(db).add(title=tv.title, sid, message = SubscribeChain().add(title=tv.title,
year=tv.year, year=tv.year,
season=season.get("seasonNumber"), season=season.get("seasonNumber"),
tmdbid=tv.tmdbId, tmdbid=tv.tmdbId,
mtype=MediaType.TV, mtype=MediaType.TV,
userid="Seerr") userid="Seerr")
if sid: if sid:
return { return {

View File

@ -7,7 +7,6 @@ from typing import Optional, Any, Tuple, List, Set, Union, Dict
from qbittorrentapi import TorrentFilesList from qbittorrentapi import TorrentFilesList
from ruamel.yaml import CommentedMap from ruamel.yaml import CommentedMap
from sqlalchemy.orm import Session
from transmission_rpc import File from transmission_rpc import File
from app.core.config import settings from app.core.config import settings
@ -28,11 +27,10 @@ class ChainBase(metaclass=ABCMeta):
处理链基类 处理链基类
""" """
def __init__(self, db: Session = None): def __init__(self):
""" """
公共初始化 公共初始化
""" """
self._db = db
self.modulemanager = ModuleManager() self.modulemanager = ModuleManager()
self.eventmanager = EventManager() self.eventmanager = EventManager()

View File

@ -3,7 +3,6 @@ from typing import Tuple, Optional
from urllib.parse import urljoin from urllib.parse import urljoin
from lxml import etree from lxml import etree
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.site import SiteChain from app.chain.site import SiteChain
@ -25,13 +24,13 @@ class CookieCloudChain(ChainBase):
CookieCloud处理链 CookieCloud处理链
""" """
def __init__(self, db: Session = None): def __init__(self):
super().__init__(db) super().__init__()
self.siteoper = SiteOper(self._db) self.siteoper = SiteOper()
self.siteiconoper = SiteIconOper(self._db) self.siteiconoper = SiteIconOper()
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
self.rsshelper = RssHelper() self.rsshelper = RssHelper()
self.sitechain = SiteChain(self._db) self.sitechain = SiteChain()
self.message = MessageHelper() self.message = MessageHelper()
self.cookiecloud = CookieCloudHelper( self.cookiecloud = CookieCloudHelper(
server=settings.COOKIECLOUD_HOST, server=settings.COOKIECLOUD_HOST,

View File

@ -5,8 +5,6 @@ import time
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Set, Dict, Union from typing import List, Optional, Tuple, Set, Dict, Union
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.core.context import MediaInfo, TorrentInfo, Context from app.core.context import MediaInfo, TorrentInfo, Context
@ -27,11 +25,11 @@ class DownloadChain(ChainBase):
下载处理链 下载处理链
""" """
def __init__(self, db: Session = None): def __init__(self):
super().__init__(db) super().__init__()
self.torrent = TorrentHelper() self.torrent = TorrentHelper()
self.downloadhis = DownloadHistoryOper(self._db) self.downloadhis = DownloadHistoryOper()
self.mediaserver = MediaServerOper(self._db) self.mediaserver = MediaServerOper()
def post_download_message(self, meta: MetaBase, mediainfo: MediaInfo, torrent: TorrentInfo, def post_download_message(self, meta: MetaBase, mediainfo: MediaInfo, torrent: TorrentInfo,
channel: MessageChannel = None, channel: MessageChannel = None,

View File

@ -1,13 +1,10 @@
import json import json
import threading import threading
from typing import List, Union, Generator from typing import List, Union
from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.db import SessionFactory
from app.db.mediaserver_oper import MediaServerOper from app.db.mediaserver_oper import MediaServerOper
from app.log import logger from app.log import logger
@ -19,8 +16,9 @@ class MediaServerChain(ChainBase):
媒体服务器处理链 媒体服务器处理链
""" """
def __init__(self, db: Session = None): def __init__(self):
super().__init__(db) super().__init__()
self.dboper = MediaServerOper()
def librarys(self, server: str) -> List[schemas.MediaServerLibrary]: def librarys(self, server: str) -> List[schemas.MediaServerLibrary]:
""" """
@ -51,13 +49,10 @@ class MediaServerChain(ChainBase):
同步媒体库所有数据到本地数据库 同步媒体库所有数据到本地数据库
""" """
with lock: with lock:
# 媒体服务器同步使用独立的会话
_db = SessionFactory()
_dbOper = MediaServerOper(_db)
# 汇总统计 # 汇总统计
total_count = 0 total_count = 0
# 清空登记薄 # 清空登记薄
_dbOper.empty(server=settings.MEDIASERVER) self.dboper.empty(server=settings.MEDIASERVER)
# 同步黑名单 # 同步黑名单
sync_blacklist = settings.MEDIASERVER_SYNC_BLACKLIST.split( sync_blacklist = settings.MEDIASERVER_SYNC_BLACKLIST.split(
",") if settings.MEDIASERVER_SYNC_BLACKLIST else [] ",") if settings.MEDIASERVER_SYNC_BLACKLIST else []
@ -79,6 +74,7 @@ class MediaServerChain(ChainBase):
continue continue
if not item.item_id: if not item.item_id:
continue continue
logger.debug(f"正在同步 {item.title} ...")
# 计数 # 计数
library_count += 1 library_count += 1
seasoninfo = {} seasoninfo = {}
@ -93,11 +89,8 @@ class MediaServerChain(ChainBase):
item_dict = item.dict() item_dict = item.dict()
item_dict['seasoninfo'] = json.dumps(seasoninfo) item_dict['seasoninfo'] = json.dumps(seasoninfo)
item_dict['item_type'] = item_type item_dict['item_type'] = item_type
_dbOper.add(**item_dict) self.dboper.add(**item_dict)
logger.info(f"{mediaserver} 媒体库 {library.name} 同步完成,共同步数量:{library_count}") logger.info(f"{mediaserver} 媒体库 {library.name} 同步完成,共同步数量:{library_count}")
# 总数累加 # 总数累加
total_count += library_count total_count += library_count
# 关闭数据库连接
if _db:
_db.close()
logger.info("【MediaServer】媒体库数据同步完成同步数量%s" % total_count) logger.info("【MediaServer】媒体库数据同步完成同步数量%s" % total_count)

View File

@ -28,11 +28,11 @@ class MessageChain(ChainBase):
# 每页数据量 # 每页数据量
_page_size: int = 8 _page_size: int = 8
def __init__(self, db: Session = None): def __init__(self):
super().__init__(db) super().__init__()
self.downloadchain = DownloadChain(self._db) self.downloadchain = DownloadChain()
self.subscribechain = SubscribeChain(self._db) self.subscribechain = SubscribeChain()
self.searchchain = SearchChain(self._db) self.searchchain = SearchChain()
self.medtachain = MediaChain() self.medtachain = MediaChain()
self.torrent = TorrentHelper() self.torrent = TorrentHelper()
self.eventmanager = EventManager() self.eventmanager = EventManager()

View File

@ -5,8 +5,6 @@ from datetime import datetime
from typing import Dict from typing import Dict
from typing import List, Optional from typing import List, Optional
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.core.context import Context from app.core.context import Context
from app.core.context import MediaInfo, TorrentInfo from app.core.context import MediaInfo, TorrentInfo
@ -26,8 +24,8 @@ class SearchChain(ChainBase):
站点资源搜索处理链 站点资源搜索处理链
""" """
def __init__(self, db: Session = None): def __init__(self):
super().__init__(db) super().__init__()
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
self.progress = ProgressHelper() self.progress = ProgressHelper()
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper()

View File

@ -1,8 +1,6 @@
import re import re
from typing import Union, Tuple from typing import Union, Tuple
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.db.models.site import Site from app.db.models.site import Site
@ -23,9 +21,9 @@ class SiteChain(ChainBase):
站点管理处理链 站点管理处理链
""" """
def __init__(self, db: Session = None): def __init__(self):
super().__init__(db) super().__init__()
self.siteoper = SiteOper(self._db) self.siteoper = SiteOper()
self.cookiehelper = CookieHelper() self.cookiehelper = CookieHelper()
self.message = MessageHelper() self.message = MessageHelper()

View File

@ -3,8 +3,6 @@ import re
from datetime import datetime from datetime import datetime
from typing import Dict, List, Optional, Union, Tuple from typing import Dict, List, Optional, Union, Tuple
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.douban import DoubanChain from app.chain.douban import DoubanChain
from app.chain.download import DownloadChain from app.chain.download import DownloadChain
@ -27,11 +25,11 @@ class SubscribeChain(ChainBase):
订阅管理处理链 订阅管理处理链
""" """
def __init__(self, db: Session = None): def __init__(self):
super().__init__(db) super().__init__()
self.downloadchain = DownloadChain(self._db) self.downloadchain = DownloadChain()
self.searchchain = SearchChain(self._db) self.searchchain = SearchChain()
self.subscribeoper = SubscribeOper(self._db) self.subscribeoper = SubscribeOper()
self.torrentschain = TorrentsChain() self.torrentschain = TorrentsChain()
self.message = MessageHelper() self.message = MessageHelper()
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper()

View File

@ -7,7 +7,6 @@ from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.core.context import TorrentInfo, Context, MediaInfo from app.core.context import TorrentInfo, Context, MediaInfo
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.db import SessionFactory
from app.db.site_oper import SiteOper from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.helper.rss import RssHelper from app.helper.rss import RssHelper
@ -28,10 +27,9 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
_rss_file = "__rss_cache__" _rss_file = "__rss_cache__"
def __init__(self): def __init__(self):
self._db = SessionFactory() super().__init__()
super().__init__(self._db)
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
self.siteoper = SiteOper(self._db) self.siteoper = SiteOper()
self.rsshelper = RssHelper() self.rsshelper = RssHelper()
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper()

View File

@ -4,8 +4,6 @@ import threading
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union, Dict from typing import List, Optional, Tuple, Union, Dict
from sqlalchemy.orm import Session
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.media import MediaChain from app.chain.media import MediaChain
from app.chain.tmdb import TmdbChain from app.chain.tmdb import TmdbChain
@ -35,10 +33,10 @@ class TransferChain(ChainBase):
文件转移处理链 文件转移处理链
""" """
def __init__(self, db: Session = None): def __init__(self):
super().__init__(db) super().__init__()
self.downloadhis = DownloadHistoryOper(self._db) self.downloadhis = DownloadHistoryOper()
self.transferhis = TransferHistoryOper(self._db) self.transferhis = TransferHistoryOper()
self.progress = ProgressHelper() self.progress = ProgressHelper()
self.mediachain = MediaChain() self.mediachain = MediaChain()
self.tmdbchain = TmdbChain() self.tmdbchain = TmdbChain()

View File

@ -13,7 +13,6 @@ from app.chain.transfer import TransferChain
from app.core.event import Event as ManagerEvent from app.core.event import Event as ManagerEvent
from app.core.event import eventmanager, EventManager from app.core.event import eventmanager, EventManager
from app.core.plugin import PluginManager from app.core.plugin import PluginManager
from app.db import SessionFactory
from app.helper.thread import ThreadHelper from app.helper.thread import ThreadHelper
from app.log import logger from app.log import logger
from app.scheduler import Scheduler from app.scheduler import Scheduler
@ -43,14 +42,12 @@ class Command(metaclass=Singleton):
_event = threading.Event() _event = threading.Event()
def __init__(self): def __init__(self):
# 数据库连接
self._db = SessionFactory()
# 事件管理器 # 事件管理器
self.eventmanager = EventManager() self.eventmanager = EventManager()
# 插件管理器 # 插件管理器
self.pluginmanager = PluginManager() self.pluginmanager = PluginManager()
# 处理链 # 处理链
self.chain = CommandChian(self._db) self.chain = CommandChian()
# 定时服务管理 # 定时服务管理
self.scheduler = Scheduler() self.scheduler = Scheduler()
# 线程管理器 # 线程管理器
@ -64,23 +61,23 @@ class Command(metaclass=Singleton):
"category": "站点" "category": "站点"
}, },
"/sites": { "/sites": {
"func": SiteChain(self._db).remote_list, "func": SiteChain().remote_list,
"description": "查询站点", "description": "查询站点",
"category": "站点", "category": "站点",
"data": {} "data": {}
}, },
"/site_cookie": { "/site_cookie": {
"func": SiteChain(self._db).remote_cookie, "func": SiteChain().remote_cookie,
"description": "更新站点Cookie", "description": "更新站点Cookie",
"data": {} "data": {}
}, },
"/site_enable": { "/site_enable": {
"func": SiteChain(self._db).remote_enable, "func": SiteChain().remote_enable,
"description": "启用站点", "description": "启用站点",
"data": {} "data": {}
}, },
"/site_disable": { "/site_disable": {
"func": SiteChain(self._db).remote_disable, "func": SiteChain().remote_disable,
"description": "禁用站点", "description": "禁用站点",
"data": {} "data": {}
}, },
@ -91,7 +88,7 @@ class Command(metaclass=Singleton):
"category": "管理" "category": "管理"
}, },
"/subscribes": { "/subscribes": {
"func": SubscribeChain(self._db).remote_list, "func": SubscribeChain().remote_list,
"description": "查询订阅", "description": "查询订阅",
"category": "订阅", "category": "订阅",
"data": {} "data": {}
@ -109,7 +106,7 @@ class Command(metaclass=Singleton):
"category": "订阅" "category": "订阅"
}, },
"/subscribe_delete": { "/subscribe_delete": {
"func": SubscribeChain(self._db).remote_delete, "func": SubscribeChain().remote_delete,
"description": "删除订阅", "description": "删除订阅",
"data": {} "data": {}
}, },
@ -119,7 +116,7 @@ class Command(metaclass=Singleton):
"description": "订阅元数据更新" "description": "订阅元数据更新"
}, },
"/downloading": { "/downloading": {
"func": DownloadChain(self._db).remote_downloading, "func": DownloadChain().remote_downloading,
"description": "正在下载", "description": "正在下载",
"category": "管理", "category": "管理",
"data": {} "data": {}
@ -131,7 +128,7 @@ class Command(metaclass=Singleton):
"category": "管理" "category": "管理"
}, },
"/redo": { "/redo": {
"func": TransferChain(self._db).remote_transfer, "func": TransferChain().remote_transfer,
"description": "手动整理", "description": "手动整理",
"data": {} "data": {}
}, },
@ -277,8 +274,6 @@ class Command(metaclass=Singleton):
""" """
self._event.set() self._event.set()
self._thread.join() self._thread.join()
if self._db:
self._db.close()
def get_commands(self): def get_commands(self):
""" """

View File

@ -4,7 +4,6 @@ import cn2an
import regex as re import regex as re
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.schemas.types import SystemConfigKey from app.schemas.types import SystemConfigKey
from app.utils.singleton import Singleton from app.utils.singleton import Singleton

View File

@ -22,16 +22,15 @@ def init_db():
# 全量建表 # 全量建表
Base.metadata.create_all(bind=Engine) Base.metadata.create_all(bind=Engine)
# 初始化超级管理员 # 初始化超级管理员
db = SessionFactory() with SessionFactory() as db:
user = User.get_by_name(db=db, name=settings.SUPERUSER) user = User.get_by_name(db=db, name=settings.SUPERUSER)
if not user: if not user:
user = User( user = User(
name=settings.SUPERUSER, name=settings.SUPERUSER,
hashed_password=get_password_hash(settings.SUPERUSER_PASSWORD), hashed_password=get_password_hash(settings.SUPERUSER_PASSWORD),
is_superuser=True, is_superuser=True,
) )
user.create(db) user.create(db)
db.close()
def update_db(): def update_db():

View File

@ -31,6 +31,12 @@ class SiteOper(DbOper):
""" """
return Site.list(self._db) return Site.list(self._db)
def list_order_by_pri(self) -> List[Site]:
"""
获取站点列表
"""
return Site.list_order_by_pri(self._db)
def list_active(self) -> List[Site]: def list_active(self) -> List[Site]:
""" """
按状态获取站点列表 按状态获取站点列表

View File

@ -1,7 +1,7 @@
import json import json
from typing import Any, Union from typing import Any, Union
from app.db import DbOper, SessionFactory from app.db import DbOper
from app.db.models.systemconfig import SystemConfig from app.db.models.systemconfig import SystemConfig
from app.schemas.types import SystemConfigKey from app.schemas.types import SystemConfigKey
from app.utils.object import ObjectUtils from app.utils.object import ObjectUtils
@ -16,8 +16,7 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
""" """
加载配置到内存 加载配置到内存
""" """
self._db = SessionFactory() super().__init__()
super().__init__(self._db)
for item in SystemConfig.list(self._db): for item in SystemConfig.list(self._db):
if ObjectUtils.is_obj(item.value): if ObjectUtils.is_obj(item.value):
self.__SYSTEMCONF[item.key] = json.loads(item.value) self.__SYSTEMCONF[item.key] = json.loads(item.value)

Binary file not shown.

View File

@ -5,7 +5,6 @@ from typing import Any, List, Dict, Tuple
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.core.event import EventManager from app.core.event import EventManager
from app.db import SessionFactory
from app.db.models import Base from app.db.models import Base
from app.db.plugindata_oper import PluginDataOper from app.db.plugindata_oper import PluginDataOper
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
@ -36,12 +35,10 @@ class _PluginBase(metaclass=ABCMeta):
plugin_desc: str = "" plugin_desc: str = ""
def __init__(self): def __init__(self):
# 数据库连接
self.db = SessionFactory()
# 插件数据 # 插件数据
self.plugindata = PluginDataOper(self.db) self.plugindata = PluginDataOper()
# 处理链 # 处理链
self.chain = PluginChian(self.db) self.chain = PluginChian()
# 系统配置 # 系统配置
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper()
# 系统消息 # 系统消息
@ -186,8 +183,4 @@ class _PluginBase(metaclass=ABCMeta):
)) ))
def close(self): def close(self):
""" pass
关闭数据库连接
"""
if self.db:
self.db.close()

View File

@ -72,8 +72,8 @@ class AutoClean(_PluginBase):
# 加载模块 # 加载模块
if self._enabled: if self._enabled:
self._downloadhis = DownloadHistoryOper(self.db) self._downloadhis = DownloadHistoryOper()
self._transferhis = TransferHistoryOper(self.db) self._transferhis = TransferHistoryOper()
# 定时服务 # 定时服务
self._scheduler = BackgroundScheduler(timezone=settings.TZ) self._scheduler = BackgroundScheduler(timezone=settings.TZ)
@ -181,12 +181,12 @@ class AutoClean(_PluginBase):
for history in transferhis_list: for history in transferhis_list:
# 册除媒体库文件 # 册除媒体库文件
if str(self._cleantype == "dest") or str(self._cleantype == "all"): if str(self._cleantype == "dest") or str(self._cleantype == "all"):
TransferChain(self.db).delete_files(Path(history.dest)) TransferChain().delete_files(Path(history.dest))
# 删除记录 # 删除记录
self._transferhis.delete(history.id) self._transferhis.delete(history.id)
# 删除源文件 # 删除源文件
if str(self._cleantype == "src") or str(self._cleantype == "all"): if str(self._cleantype == "src") or str(self._cleantype == "all"):
TransferChain(self.db).delete_files(Path(history.src)) TransferChain().delete_files(Path(history.src))
# 发送事件 # 发送事件
eventmanager.send_event( eventmanager.send_event(
EventType.DownloadFileDeleted, EventType.DownloadFileDeleted,

View File

@ -15,6 +15,7 @@ from app import schemas
from app.core.config import settings from app.core.config import settings
from app.core.event import EventManager, eventmanager, Event from app.core.event import EventManager, eventmanager, Event
from app.db.models.site import Site from app.db.models.site import Site
from app.db.site_oper import SiteOper
from app.helper.browser import PlaywrightHelper from app.helper.browser import PlaywrightHelper
from app.helper.cloudflare import under_challenge from app.helper.cloudflare import under_challenge
from app.helper.module import ModuleHelper from app.helper.module import ModuleHelper
@ -52,6 +53,7 @@ class AutoSignIn(_PluginBase):
# 私有属性 # 私有属性
sites: SitesHelper = None sites: SitesHelper = None
siteoper: SiteOper = None
# 事件管理器 # 事件管理器
event: EventManager = None event: EventManager = None
# 定时器 # 定时器
@ -74,6 +76,7 @@ class AutoSignIn(_PluginBase):
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self.sites = SitesHelper() self.sites = SitesHelper()
self.siteoper = SiteOper()
self.event = EventManager() self.event = EventManager()
# 停止现有任务 # 停止现有任务
@ -248,7 +251,7 @@ class AutoSignIn(_PluginBase):
customSites = self.__custom_sites() customSites = self.__custom_sites()
site_options = ([{"title": site.name, "value": site.id} site_options = ([{"title": site.name, "value": site.id}
for site in Site.list_order_by_pri(self.db)] for site in self.siteoper.list_order_by_pri()]
+ [{"title": site.get("name"), "value": site.get("id")} + [{"title": site.get("name"), "value": site.get("id")}
for site in customSites]) for site in customSites])
return [ return [
@ -456,7 +459,7 @@ class AutoSignIn(_PluginBase):
"retry_keyword": "错误|失败" "retry_keyword": "错误|失败"
} }
def __custom_sites(self) -> List[dict]: def __custom_sites(self) -> List[Any]:
custom_sites = [] custom_sites = []
custom_sites_config = self.get_config("CustomSites") custom_sites_config = self.get_config("CustomSites")
if custom_sites_config and custom_sites_config.get("enabled"): if custom_sites_config and custom_sites_config.get("enabled"):

View File

@ -62,7 +62,7 @@ class BestFilmVersion(_PluginBase):
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self._cache_path = settings.TEMP_PATH / "__best_film_version_cache__" self._cache_path = settings.TEMP_PATH / "__best_film_version_cache__"
self.subscribechain = SubscribeChain(self.db) self.subscribechain = SubscribeChain()
# 停止现有任务 # 停止现有任务
self.stop_service() self.stop_service()

View File

@ -96,9 +96,9 @@ class DirMonitor(_PluginBase):
_event = threading.Event() _event = threading.Event()
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self.transferhis = TransferHistoryOper(self.db) self.transferhis = TransferHistoryOper()
self.downloadhis = DownloadHistoryOper(self.db) self.downloadhis = DownloadHistoryOper()
self.transferchian = TransferChain(self.db) self.transferchian = TransferChain()
self.tmdbchain = TmdbChain() self.tmdbchain = TmdbChain()
# 清空配置 # 清空配置
self._dirconf = {} self._dirconf = {}

View File

@ -66,8 +66,8 @@ class DoubanRank(_PluginBase):
_clearflag = False _clearflag = False
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self.downloadchain = DownloadChain(self.db) self.downloadchain = DownloadChain()
self.subscribechain = SubscribeChain(self.db) self.subscribechain = SubscribeChain()
if config: if config:
self._enabled = config.get("enabled") self._enabled = config.get("enabled")

View File

@ -66,9 +66,9 @@ class DoubanSync(_PluginBase):
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self.rsshelper = RssHelper() self.rsshelper = RssHelper()
self.downloadchain = DownloadChain(self.db) self.downloadchain = DownloadChain()
self.searchchain = SearchChain(self.db) self.searchchain = SearchChain()
self.subscribechain = SubscribeChain(self.db) self.subscribechain = SubscribeChain()
# 停止现有任务 # 停止现有任务
self.stop_service() self.stop_service()

View File

@ -57,7 +57,7 @@ class DownloadingMsg(_PluginBase):
# 加载模块 # 加载模块
if self._enabled: if self._enabled:
self._downloadhis = DownloadHistoryOper(self.db) self._downloadhis = DownloadHistoryOper()
# 定时服务 # 定时服务
self._scheduler = BackgroundScheduler(timezone=settings.TZ) self._scheduler = BackgroundScheduler(timezone=settings.TZ)
@ -80,7 +80,7 @@ class DownloadingMsg(_PluginBase):
定时推送正在下载进度 定时推送正在下载进度
""" """
# 正在下载种子 # 正在下载种子
torrents = DownloadChain(self.db).list_torrents(status=TorrentStatus.DOWNLOADING) torrents = DownloadChain().list_torrents(status=TorrentStatus.DOWNLOADING)
if not torrents: if not torrents:
logger.info("当前没有正在下载的任务!") logger.info("当前没有正在下载的任务!")
return return

View File

@ -11,6 +11,7 @@ from lxml import etree
from ruamel.yaml import CommentedMap from ruamel.yaml import CommentedMap
from app.core.config import settings from app.core.config import settings
from app.db.site_oper import SiteOper
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper
from app.core.event import eventmanager from app.core.event import eventmanager
@ -55,6 +56,7 @@ class IYUUAutoSeed(_PluginBase):
qb = None qb = None
tr = None tr = None
sites = None sites = None
siteoper = None
torrent = None torrent = None
# 开关 # 开关
_enabled = False _enabled = False
@ -96,6 +98,7 @@ class IYUUAutoSeed(_PluginBase):
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self.sites = SitesHelper() self.sites = SitesHelper()
self.siteoper = SiteOper()
self.torrent = TorrentHelper() self.torrent = TorrentHelper()
# 读取配置 # 读取配置
if config: if config:
@ -176,7 +179,7 @@ class IYUUAutoSeed(_PluginBase):
""" """
# 站点的可选项 # 站点的可选项
site_options = [{"title": site.name, "value": site.id} site_options = [{"title": site.name, "value": site.id}
for site in Site.list_order_by_pri(self.db)] for site in self.siteoper.list_order_by_pri()]
return [ return [
{ {
'component': 'VForm', 'component': 'VForm',

View File

@ -70,7 +70,7 @@ class LibraryScraper(_PluginBase):
# 启动定时任务 & 立即运行一次 # 启动定时任务 & 立即运行一次
if self._enabled or self._onlyonce: if self._enabled or self._onlyonce:
self.transferhis = TransferHistoryOper(self.db) self.transferhis = TransferHistoryOper()
self._scheduler = BackgroundScheduler(timezone=settings.TZ) self._scheduler = BackgroundScheduler(timezone=settings.TZ)
if self._cron: if self._cron:
logger.info(f"媒体库刮削服务启动,周期:{self._cron}") logger.info(f"媒体库刮削服务启动,周期:{self._cron}")

View File

@ -62,7 +62,7 @@ class MediaSyncDel(_PluginBase):
tr = None tr = None
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self._transferchain = TransferChain(self.db) self._transferchain = TransferChain()
self._transferhis = self._transferchain.transferhis self._transferhis = self._transferchain.transferhis
self._downloadhis = self._transferchain.downloadhis self._downloadhis = self._transferchain.downloadhis
self.episode = Episode() self.episode = Episode()

View File

@ -46,9 +46,9 @@ class NAStoolSync(_PluginBase):
_transfer = False _transfer = False
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self._transferhistory = TransferHistoryOper(self.db) self._transferhistory = TransferHistoryOper()
self._plugindata = PluginDataOper(self.db) self._plugindata = PluginDataOper()
self._downloadhistory = DownloadHistoryOper(self.db) self._downloadhistory = DownloadHistoryOper()
if config: if config:
self._clear = config.get("clear") self._clear = config.get("clear")

View File

@ -69,7 +69,7 @@ class PersonMeta(_PluginBase):
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self.tmdbchain = TmdbChain() self.tmdbchain = TmdbChain()
self.mschain = MediaServerChain(self.db) self.mschain = MediaServerChain()
if config: if config:
self._enabled = config.get("enabled") self._enabled = config.get("enabled")
self._onlyonce = config.get("onlyonce") self._onlyonce = config.get("onlyonce")

View File

@ -69,9 +69,9 @@ class RssSubscribe(_PluginBase):
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self.rsshelper = RssHelper() self.rsshelper = RssHelper()
self.downloadchain = DownloadChain(self.db) self.downloadchain = DownloadChain()
self.searchchain = SearchChain(self.db) self.searchchain = SearchChain()
self.subscribechain = SubscribeChain(self.db) self.subscribechain = SubscribeChain()
# 停止现有任务 # 停止现有任务
self.stop_service() self.stop_service()

View File

@ -16,6 +16,7 @@ from app.core.config import settings
from app.core.event import Event from app.core.event import Event
from app.core.event import eventmanager from app.core.event import eventmanager
from app.db.models.site import Site from app.db.models.site import Site
from app.db.site_oper import SiteOper
from app.helper.browser import PlaywrightHelper from app.helper.browser import PlaywrightHelper
from app.helper.module import ModuleHelper from app.helper.module import ModuleHelper
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper
@ -56,6 +57,7 @@ class SiteStatistic(_PluginBase):
# 私有属性 # 私有属性
sites = None sites = None
siteoper = None
_scheduler: Optional[BackgroundScheduler] = None _scheduler: Optional[BackgroundScheduler] = None
_last_update_time: Optional[datetime] = None _last_update_time: Optional[datetime] = None
_sites_data: dict = {} _sites_data: dict = {}
@ -72,6 +74,7 @@ class SiteStatistic(_PluginBase):
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
self.sites = SitesHelper() self.sites = SitesHelper()
self.siteoper = SiteOper()
# 停止现有任务 # 停止现有任务
self.stop_service() self.stop_service()
@ -187,7 +190,7 @@ class SiteStatistic(_PluginBase):
customSites = self.__custom_sites() customSites = self.__custom_sites()
site_options = ([{"title": site.name, "value": site.id} site_options = ([{"title": site.name, "value": site.id}
for site in Site.list_order_by_pri(self.db)] for site in self.siteoper.list_order_by_pri()]
+ [{"title": site.get("name"), "value": site.get("id")} + [{"title": site.get("name"), "value": site.get("id")}
for site in customSites]) for site in customSites])
@ -1122,7 +1125,7 @@ class SiteStatistic(_PluginBase):
self.save_data("last_update_time", key) self.save_data("last_update_time", key)
logger.info("站点数据刷新完成") logger.info("站点数据刷新完成")
def __custom_sites(self) -> List[dict]: def __custom_sites(self) -> List[Any]:
custom_sites = [] custom_sites = []
custom_sites_config = self.get_config("CustomSites") custom_sites_config = self.get_config("CustomSites")
if custom_sites_config and custom_sites_config.get("enabled"): if custom_sites_config and custom_sites_config.get("enabled"):

View File

@ -59,8 +59,8 @@ class SyncDownloadFiles(_PluginBase):
self.qb = Qbittorrent() self.qb = Qbittorrent()
self.tr = Transmission() self.tr = Transmission()
self.downloadhis = DownloadHistoryOper(self.db) self.downloadhis = DownloadHistoryOper()
self.transferhis = TransferHistoryOper(self.db) self.transferhis = TransferHistoryOper()
if config: if config:
self._enabled = config.get('enabled') self._enabled = config.get('enabled')

View File

@ -15,7 +15,6 @@ from app.chain.subscribe import SubscribeChain
from app.chain.tmdb import TmdbChain from app.chain.tmdb import TmdbChain
from app.chain.transfer import TransferChain from app.chain.transfer import TransferChain
from app.core.config import settings from app.core.config import settings
from app.db import SessionFactory
from app.log import logger from app.log import logger
from app.utils.singleton import Singleton from app.utils.singleton import Singleton
from app.utils.timer import TimerUtils from app.utils.timer import TimerUtils
@ -44,32 +43,30 @@ class Scheduler(metaclass=Singleton):
_event = threading.Event() _event = threading.Event()
def __init__(self): def __init__(self):
# 数据库连接
self._db = SessionFactory()
# 各服务的运行状态 # 各服务的运行状态
self._jobs = { self._jobs = {
"cookiecloud": { "cookiecloud": {
"func": CookieCloudChain(self._db).process, "func": CookieCloudChain().process,
"running": False, "running": False,
}, },
"mediaserver_sync": { "mediaserver_sync": {
"func": MediaServerChain(self._db).sync, "func": MediaServerChain().sync,
"running": False, "running": False,
}, },
"subscribe_tmdb": { "subscribe_tmdb": {
"func": SubscribeChain(self._db).check, "func": SubscribeChain().check,
"running": False, "running": False,
}, },
"subscribe_search": { "subscribe_search": {
"func": SubscribeChain(self._db).search, "func": SubscribeChain().search,
"running": False, "running": False,
}, },
"subscribe_refresh": { "subscribe_refresh": {
"func": SubscribeChain(self._db).refresh, "func": SubscribeChain().refresh,
"running": False, "running": False,
}, },
"transfer": { "transfer": {
"func": TransferChain(self._db).process, "func": TransferChain().process,
"running": False, "running": False,
} }
} }
@ -189,7 +186,7 @@ class Scheduler(metaclass=Singleton):
# 后台刷新TMDB壁纸 # 后台刷新TMDB壁纸
self._scheduler.add_job( self._scheduler.add_job(
TmdbChain(self._db).get_random_wallpager, TmdbChain().get_random_wallpager,
"interval", "interval",
minutes=30, minutes=30,
next_run_time=datetime.now(pytz.timezone(settings.TZ)) + timedelta(seconds=3) next_run_time=datetime.now(pytz.timezone(settings.TZ)) + timedelta(seconds=3)
@ -197,7 +194,7 @@ class Scheduler(metaclass=Singleton):
# 公共定时服务 # 公共定时服务
self._scheduler.add_job( self._scheduler.add_job(
SchedulerChain(self._db).scheduler_job, SchedulerChain().scheduler_job,
"interval", "interval",
minutes=10 minutes=10
) )
@ -264,5 +261,3 @@ class Scheduler(metaclass=Singleton):
self._event.set() self._event.set()
if self._scheduler.running: if self._scheduler.running:
self._scheduler.shutdown() self._scheduler.shutdown()
if self._db:
self._db.close()

View File

@ -6,8 +6,6 @@ Create Date: 2023-09-19 21:34:41.994617
""" """
from alembic import op from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '232dfa044617' revision = '232dfa044617'

View File

@ -6,8 +6,6 @@ Create Date: 2023-09-23 08:25:59.776488
""" """
from alembic import op from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '30329639c12b' revision = '30329639c12b'

View File

@ -8,7 +8,6 @@ Create Date: 2023-09-28 10:15:58.410003
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = 'b2f011d3a8b7' revision = 'b2f011d3a8b7'
down_revision = '30329639c12b' down_revision = '30329639c12b'
@ -25,5 +24,6 @@ def upgrade() -> None:
pass pass
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
pass pass