fix db session
This commit is contained in:
parent
ccc249f29d
commit
0e36d003c0
@ -6,8 +6,6 @@ Create Date: 2023-08-28 13:21:45.152012
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '52ab4930be04'
|
||||
|
@ -64,14 +64,13 @@ def wechat_verify(echostr: str, msg_signature: str,
|
||||
|
||||
|
||||
@router.get("/switchs", summary="查询通知消息渠道开关", response_model=List[NotificationSwitch])
|
||||
def read_switchs(db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
def read_switchs(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询通知消息渠道开关
|
||||
"""
|
||||
return_list = []
|
||||
# 读取数据库
|
||||
switchs = SystemConfigOper(db).get(SystemConfigKey.NotificationChannels)
|
||||
switchs = SystemConfigOper().get(SystemConfigKey.NotificationChannels)
|
||||
if not switchs:
|
||||
for noti in NotificationType:
|
||||
return_list.append(NotificationSwitch(mtype=noti.value, wechat=True, telegram=True, slack=True))
|
||||
@ -83,7 +82,6 @@ def read_switchs(db: Session = Depends(get_db),
|
||||
|
||||
@router.post("/switchs", summary="设置通知消息渠道开关", response_model=schemas.Response)
|
||||
def set_switchs(switchs: List[NotificationSwitch],
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询通知消息渠道开关
|
||||
@ -92,6 +90,6 @@ def set_switchs(switchs: List[NotificationSwitch],
|
||||
for switch in switchs:
|
||||
switch_list.append(switch.dict())
|
||||
# 存入数据库
|
||||
SystemConfigOper(db).set(SystemConfigKey.NotificationChannels, switch_list)
|
||||
SystemConfigOper().set(SystemConfigKey.NotificationChannels, switch_list)
|
||||
|
||||
return schemas.Response(success=True)
|
||||
|
@ -22,28 +22,26 @@ def all_plugins(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
|
||||
|
||||
@router.get("/installed", summary="已安装插件", response_model=List[str])
|
||||
def installed_plugins(db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
def installed_plugins(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
查询用户已安装插件清单
|
||||
"""
|
||||
return SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
return SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
|
||||
|
||||
@router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response)
|
||||
def install_plugin(plugin_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
安装插件
|
||||
"""
|
||||
# 已安装插件
|
||||
install_plugins = SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 安装插件
|
||||
if plugin_id not in install_plugins:
|
||||
install_plugins.append(plugin_id)
|
||||
# 保存设置
|
||||
SystemConfigOper(db).set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||
# 重载插件管理器
|
||||
PluginManager().init_config()
|
||||
return schemas.Response(success=True)
|
||||
@ -93,19 +91,18 @@ def set_plugin_config(plugin_id: str, conf: dict,
|
||||
|
||||
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
|
||||
def uninstall_plugin(plugin_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
卸载插件
|
||||
"""
|
||||
# 删除已安装信息
|
||||
install_plugins = SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
for plugin in install_plugins:
|
||||
if plugin == plugin_id:
|
||||
install_plugins.remove(plugin)
|
||||
break
|
||||
# 保存
|
||||
SystemConfigOper(db).set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||
# 重载插件管理器
|
||||
PluginManager().init_config()
|
||||
return schemas.Response(success=True)
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import List, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app import schemas
|
||||
|
@ -117,9 +117,9 @@ def cookie_cloud_sync(db: Session = Depends(get_db),
|
||||
清空所有站点数据并重新同步CookieCloud站点信息
|
||||
"""
|
||||
Site.reset(db)
|
||||
SystemConfigOper(db).set(SystemConfigKey.IndexerSites, [])
|
||||
SystemConfigOper(db).set(SystemConfigKey.RssSites, [])
|
||||
CookieCloudChain(db).process(manual=True)
|
||||
SystemConfigOper().set(SystemConfigKey.IndexerSites, [])
|
||||
SystemConfigOper().set(SystemConfigKey.RssSites, [])
|
||||
CookieCloudChain().process(manual=True)
|
||||
# 插件站点删除
|
||||
EventManager().send_event(EventType.SiteDeleted,
|
||||
{
|
||||
@ -203,7 +203,7 @@ def site_resource(site_id: int,
|
||||
status_code=404,
|
||||
detail=f"站点 {site_id} 不存在",
|
||||
)
|
||||
torrents = TorrentsChain(db).browse(domain=site.domain)
|
||||
torrents = TorrentsChain().browse(domain=site.domain)
|
||||
if not torrents:
|
||||
return []
|
||||
return [torrent.to_dict() for torrent in torrents]
|
||||
@ -234,7 +234,7 @@ def read_rss_sites(db: Session = Depends(get_db)) -> List[dict]:
|
||||
获取站点列表
|
||||
"""
|
||||
# 选中的rss站点
|
||||
rss_sites = SystemConfigOper(db).get(SystemConfigKey.RssSites)
|
||||
rss_sites = SystemConfigOper().get(SystemConfigKey.RssSites)
|
||||
# 所有站点
|
||||
all_site = Site.list_order_by_pri(db)
|
||||
if not rss_sites or not all_site:
|
||||
|
@ -63,24 +63,22 @@ def get_progress(process_type: str, token: str):
|
||||
|
||||
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
|
||||
def get_setting(key: str,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
查询系统设置
|
||||
"""
|
||||
return schemas.Response(success=True, data={
|
||||
"value": SystemConfigOper(db).get(key)
|
||||
"value": SystemConfigOper().get(key)
|
||||
})
|
||||
|
||||
|
||||
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
|
||||
def set_setting(key: str, value: Union[list, dict, str, int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
更新系统设置
|
||||
"""
|
||||
SystemConfigOper(db).set(key, value)
|
||||
SystemConfigOper().set(key, value)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@ -185,9 +183,9 @@ def ruletest(title: str,
|
||||
description=subtitle,
|
||||
)
|
||||
if ruletype == "2":
|
||||
rule_string = SystemConfigOper(db).get(SystemConfigKey.FilterRules2)
|
||||
rule_string = SystemConfigOper().get(SystemConfigKey.FilterRules2)
|
||||
else:
|
||||
rule_string = SystemConfigOper(db).get(SystemConfigKey.FilterRules)
|
||||
rule_string = SystemConfigOper().get(SystemConfigKey.FilterRules)
|
||||
if not rule_string:
|
||||
return schemas.Response(success=False, message="过滤规则未设置!")
|
||||
|
||||
|
@ -30,7 +30,7 @@ class RssChain(ChainBase):
|
||||
super().__init__(db)
|
||||
self.rssoper = RssOper(self._db)
|
||||
self.sites = SitesHelper()
|
||||
self.systemconfig = SystemConfigOper(self._db)
|
||||
self.systemconfig = SystemConfigOper()
|
||||
self.downloadchain = DownloadChain(self._db)
|
||||
self.message = MessageHelper()
|
||||
|
||||
|
@ -29,7 +29,7 @@ class SearchChain(ChainBase):
|
||||
super().__init__(db)
|
||||
self.siteshelper = SitesHelper()
|
||||
self.progress = ProgressHelper()
|
||||
self.systemconfig = SystemConfigOper(self._db)
|
||||
self.systemconfig = SystemConfigOper()
|
||||
self.torrenthelper = TorrentHelper()
|
||||
|
||||
def search_by_tmdbid(self, tmdbid: int, mtype: MediaType = None, area: str = "title") -> List[Context]:
|
||||
|
@ -33,7 +33,7 @@ class SubscribeChain(ChainBase):
|
||||
self.subscribeoper = SubscribeOper(self._db)
|
||||
self.torrentschain = TorrentsChain()
|
||||
self.message = MessageHelper()
|
||||
self.systemconfig = SystemConfigOper(self._db)
|
||||
self.systemconfig = SystemConfigOper()
|
||||
|
||||
def add(self, title: str, year: str,
|
||||
mtype: MediaType = None,
|
||||
|
@ -1,12 +1,12 @@
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from cachetools import cached, TTLCache
|
||||
from requests import Session
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.core.context import TorrentInfo, Context, MediaInfo
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db import SessionFactory
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.log import logger
|
||||
@ -23,10 +23,11 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
|
||||
|
||||
_cache_file = "__torrents_cache__"
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
super().__init__(db)
|
||||
def __init__(self):
|
||||
self._db = SessionFactory()
|
||||
super().__init__(self._db)
|
||||
self.siteshelper = SitesHelper()
|
||||
self.systemconfig = SystemConfigOper(self._db)
|
||||
self.systemconfig = SystemConfigOper()
|
||||
|
||||
def remote_refresh(self, channel: MessageChannel, userid: Union[str, int] = None):
|
||||
"""
|
||||
|
@ -13,7 +13,7 @@ from app.chain.transfer import TransferChain
|
||||
from app.core.event import Event as ManagerEvent
|
||||
from app.core.event import eventmanager, EventManager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db import ScopedSession
|
||||
from app.db import SessionFactory
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, MessageChannel
|
||||
from app.utils.object import ObjectUtils
|
||||
@ -41,7 +41,7 @@ class Command(metaclass=Singleton):
|
||||
|
||||
def __init__(self):
|
||||
# 数据库连接
|
||||
self._db = ScopedSession()
|
||||
self._db = SessionFactory()
|
||||
# 事件管理器
|
||||
self.eventmanager = EventManager()
|
||||
# 插件管理器
|
||||
|
@ -6,7 +6,7 @@ from alembic.config import Config
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import get_password_hash
|
||||
from app.db import Engine, ScopedSession
|
||||
from app.db import Engine, SessionFactory
|
||||
from app.db.models import Base
|
||||
from app.db.models.user import User
|
||||
from app.log import logger
|
||||
@ -22,7 +22,7 @@ def init_db():
|
||||
# 全量建表
|
||||
Base.metadata.create_all(bind=Engine)
|
||||
# 初始化超级管理员
|
||||
db = ScopedSession()
|
||||
db = SessionFactory()
|
||||
user = User.get_by_name(db=db, name=settings.SUPERUSER)
|
||||
if not user:
|
||||
user = User(
|
||||
|
@ -1,9 +1,7 @@
|
||||
import json
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import DbOper
|
||||
from app.db import DbOper, SessionFactory
|
||||
from app.db.models.systemconfig import SystemConfig
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.object import ObjectUtils
|
||||
@ -14,11 +12,12 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
# 配置对象
|
||||
__SYSTEMCONF: dict = {}
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
def __init__(self):
|
||||
"""
|
||||
加载配置到内存
|
||||
"""
|
||||
super().__init__(db)
|
||||
self._db = SessionFactory()
|
||||
super().__init__(self._db)
|
||||
for item in SystemConfig.list(self._db):
|
||||
if ObjectUtils.is_obj(item.value):
|
||||
self.__SYSTEMCONF[item.key] = json.loads(item.value)
|
||||
|
@ -592,7 +592,8 @@ class SpeedLimiter(_PluginBase):
|
||||
for allow_ipv6 in allow_ipv6s:
|
||||
if ipaddr in ipaddress.ip_network(allow_ipv6, strict=False):
|
||||
return True
|
||||
except Exception:
|
||||
except Exception as err:
|
||||
print(str(err))
|
||||
return False
|
||||
return False
|
||||
|
||||
|
@ -12,7 +12,7 @@ from app.chain.rss import RssChain
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.core.config import settings
|
||||
from app.db import ScopedSession
|
||||
from app.db import SessionFactory
|
||||
from app.log import logger
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.timer import TimerUtils
|
||||
@ -40,7 +40,7 @@ class Scheduler(metaclass=Singleton):
|
||||
|
||||
def __init__(self):
|
||||
# 数据库连接
|
||||
self._db = ScopedSession()
|
||||
self._db = SessionFactory()
|
||||
# 调试模式不启动定时服务
|
||||
if settings.DEV:
|
||||
return
|
||||
|
@ -5,8 +5,6 @@ import urllib3
|
||||
from requests import Session, Response
|
||||
from urllib3.exceptions import InsecureRequestWarning
|
||||
|
||||
from app.utils.ip import IpUtils
|
||||
|
||||
urllib3.disable_warnings(InsecureRequestWarning)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user