fix db session

This commit is contained in:
jxxghp 2023-09-09 19:26:56 +08:00
parent ccc249f29d
commit 0e36d003c0
16 changed files with 39 additions and 49 deletions

View File

@ -6,8 +6,6 @@ Create Date: 2023-08-28 13:21:45.152012
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '52ab4930be04'

View File

@ -64,14 +64,13 @@ def wechat_verify(echostr: str, msg_signature: str,
@router.get("/switchs", summary="查询通知消息渠道开关", response_model=List[NotificationSwitch])
def read_switchs(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def read_switchs(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询通知消息渠道开关
"""
return_list = []
# 读取数据库
switchs = SystemConfigOper(db).get(SystemConfigKey.NotificationChannels)
switchs = SystemConfigOper().get(SystemConfigKey.NotificationChannels)
if not switchs:
for noti in NotificationType:
return_list.append(NotificationSwitch(mtype=noti.value, wechat=True, telegram=True, slack=True))
@ -83,7 +82,6 @@ def read_switchs(db: Session = Depends(get_db),
@router.post("/switchs", summary="设置通知消息渠道开关", response_model=schemas.Response)
def set_switchs(switchs: List[NotificationSwitch],
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询通知消息渠道开关
@ -92,6 +90,6 @@ def set_switchs(switchs: List[NotificationSwitch],
for switch in switchs:
switch_list.append(switch.dict())
# 存入数据库
SystemConfigOper(db).set(SystemConfigKey.NotificationChannels, switch_list)
SystemConfigOper().set(SystemConfigKey.NotificationChannels, switch_list)
return schemas.Response(success=True)

View File

@ -22,28 +22,26 @@ def all_plugins(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/installed", summary="已安装插件", response_model=List[str])
def installed_plugins(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def installed_plugins(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询用户已安装插件清单
"""
return SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or []
return SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
@router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response)
def install_plugin(plugin_id: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
安装插件
"""
# 已安装插件
install_plugins = SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or []
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 安装插件
if plugin_id not in install_plugins:
install_plugins.append(plugin_id)
# 保存设置
SystemConfigOper(db).set(SystemConfigKey.UserInstalledPlugins, install_plugins)
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
# 重载插件管理器
PluginManager().init_config()
return schemas.Response(success=True)
@ -93,19 +91,18 @@ def set_plugin_config(plugin_id: str, conf: dict,
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
def uninstall_plugin(plugin_id: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
卸载插件
"""
# 删除已安装信息
install_plugins = SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or []
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
for plugin in install_plugins:
if plugin == plugin_id:
install_plugins.remove(plugin)
break
# 保存
SystemConfigOper(db).set(SystemConfigKey.UserInstalledPlugins, install_plugins)
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
# 重载插件管理器
PluginManager().init_config()
return schemas.Response(success=True)

View File

@ -1,6 +1,6 @@
from typing import List, Any
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app import schemas

View File

@ -117,9 +117,9 @@ def cookie_cloud_sync(db: Session = Depends(get_db),
清空所有站点数据并重新同步CookieCloud站点信息
"""
Site.reset(db)
SystemConfigOper(db).set(SystemConfigKey.IndexerSites, [])
SystemConfigOper(db).set(SystemConfigKey.RssSites, [])
CookieCloudChain(db).process(manual=True)
SystemConfigOper().set(SystemConfigKey.IndexerSites, [])
SystemConfigOper().set(SystemConfigKey.RssSites, [])
CookieCloudChain().process(manual=True)
# 插件站点删除
EventManager().send_event(EventType.SiteDeleted,
{
@ -203,7 +203,7 @@ def site_resource(site_id: int,
status_code=404,
detail=f"站点 {site_id} 不存在",
)
torrents = TorrentsChain(db).browse(domain=site.domain)
torrents = TorrentsChain().browse(domain=site.domain)
if not torrents:
return []
return [torrent.to_dict() for torrent in torrents]
@ -234,7 +234,7 @@ def read_rss_sites(db: Session = Depends(get_db)) -> List[dict]:
获取站点列表
"""
# 选中的rss站点
rss_sites = SystemConfigOper(db).get(SystemConfigKey.RssSites)
rss_sites = SystemConfigOper().get(SystemConfigKey.RssSites)
# 所有站点
all_site = Site.list_order_by_pri(db)
if not rss_sites or not all_site:

View File

@ -63,24 +63,22 @@ def get_progress(process_type: str, token: str):
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
def get_setting(key: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)):
"""
查询系统设置
"""
return schemas.Response(success=True, data={
"value": SystemConfigOper(db).get(key)
"value": SystemConfigOper().get(key)
})
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
def set_setting(key: str, value: Union[list, dict, str, int] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)):
"""
更新系统设置
"""
SystemConfigOper(db).set(key, value)
SystemConfigOper().set(key, value)
return schemas.Response(success=True)
@ -185,9 +183,9 @@ def ruletest(title: str,
description=subtitle,
)
if ruletype == "2":
rule_string = SystemConfigOper(db).get(SystemConfigKey.FilterRules2)
rule_string = SystemConfigOper().get(SystemConfigKey.FilterRules2)
else:
rule_string = SystemConfigOper(db).get(SystemConfigKey.FilterRules)
rule_string = SystemConfigOper().get(SystemConfigKey.FilterRules)
if not rule_string:
return schemas.Response(success=False, message="过滤规则未设置!")

View File

@ -30,7 +30,7 @@ class RssChain(ChainBase):
super().__init__(db)
self.rssoper = RssOper(self._db)
self.sites = SitesHelper()
self.systemconfig = SystemConfigOper(self._db)
self.systemconfig = SystemConfigOper()
self.downloadchain = DownloadChain(self._db)
self.message = MessageHelper()

View File

@ -29,7 +29,7 @@ class SearchChain(ChainBase):
super().__init__(db)
self.siteshelper = SitesHelper()
self.progress = ProgressHelper()
self.systemconfig = SystemConfigOper(self._db)
self.systemconfig = SystemConfigOper()
self.torrenthelper = TorrentHelper()
def search_by_tmdbid(self, tmdbid: int, mtype: MediaType = None, area: str = "title") -> List[Context]:

View File

@ -33,7 +33,7 @@ class SubscribeChain(ChainBase):
self.subscribeoper = SubscribeOper(self._db)
self.torrentschain = TorrentsChain()
self.message = MessageHelper()
self.systemconfig = SystemConfigOper(self._db)
self.systemconfig = SystemConfigOper()
def add(self, title: str, year: str,
mtype: MediaType = None,

View File

@ -1,12 +1,12 @@
from typing import Dict, List, Union
from cachetools import cached, TTLCache
from requests import Session
from app.chain import ChainBase
from app.core.config import settings
from app.core.context import TorrentInfo, Context, MediaInfo
from app.core.metainfo import MetaInfo
from app.db import SessionFactory
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.sites import SitesHelper
from app.log import logger
@ -23,10 +23,11 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
_cache_file = "__torrents_cache__"
def __init__(self, db: Session = None):
super().__init__(db)
def __init__(self):
self._db = SessionFactory()
super().__init__(self._db)
self.siteshelper = SitesHelper()
self.systemconfig = SystemConfigOper(self._db)
self.systemconfig = SystemConfigOper()
def remote_refresh(self, channel: MessageChannel, userid: Union[str, int] = None):
"""

View File

@ -13,7 +13,7 @@ from app.chain.transfer import TransferChain
from app.core.event import Event as ManagerEvent
from app.core.event import eventmanager, EventManager
from app.core.plugin import PluginManager
from app.db import ScopedSession
from app.db import SessionFactory
from app.log import logger
from app.schemas.types import EventType, MessageChannel
from app.utils.object import ObjectUtils
@ -41,7 +41,7 @@ class Command(metaclass=Singleton):
def __init__(self):
# 数据库连接
self._db = ScopedSession()
self._db = SessionFactory()
# 事件管理器
self.eventmanager = EventManager()
# 插件管理器

View File

@ -6,7 +6,7 @@ from alembic.config import Config
from app.core.config import settings
from app.core.security import get_password_hash
from app.db import Engine, ScopedSession
from app.db import Engine, SessionFactory
from app.db.models import Base
from app.db.models.user import User
from app.log import logger
@ -22,7 +22,7 @@ def init_db():
# 全量建表
Base.metadata.create_all(bind=Engine)
# 初始化超级管理员
db = ScopedSession()
db = SessionFactory()
user = User.get_by_name(db=db, name=settings.SUPERUSER)
if not user:
user = User(

View File

@ -1,9 +1,7 @@
import json
from typing import Any, Union
from sqlalchemy.orm import Session
from app.db import DbOper
from app.db import DbOper, SessionFactory
from app.db.models.systemconfig import SystemConfig
from app.schemas.types import SystemConfigKey
from app.utils.object import ObjectUtils
@ -14,11 +12,12 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
# 配置对象
__SYSTEMCONF: dict = {}
def __init__(self, db: Session = None):
def __init__(self):
"""
加载配置到内存
"""
super().__init__(db)
self._db = SessionFactory()
super().__init__(self._db)
for item in SystemConfig.list(self._db):
if ObjectUtils.is_obj(item.value):
self.__SYSTEMCONF[item.key] = json.loads(item.value)

View File

@ -592,7 +592,8 @@ class SpeedLimiter(_PluginBase):
for allow_ipv6 in allow_ipv6s:
if ipaddr in ipaddress.ip_network(allow_ipv6, strict=False):
return True
except Exception:
except Exception as err:
print(str(err))
return False
return False

View File

@ -12,7 +12,7 @@ from app.chain.rss import RssChain
from app.chain.subscribe import SubscribeChain
from app.chain.transfer import TransferChain
from app.core.config import settings
from app.db import ScopedSession
from app.db import SessionFactory
from app.log import logger
from app.utils.singleton import Singleton
from app.utils.timer import TimerUtils
@ -40,7 +40,7 @@ class Scheduler(metaclass=Singleton):
def __init__(self):
# 数据库连接
self._db = ScopedSession()
self._db = SessionFactory()
# 调试模式不启动定时服务
if settings.DEV:
return

View File

@ -5,8 +5,6 @@ import urllib3
from requests import Session, Response
from urllib3.exceptions import InsecureRequestWarning
from app.utils.ip import IpUtils
urllib3.disable_warnings(InsecureRequestWarning)