fix db session
This commit is contained in:
parent
ccc249f29d
commit
0e36d003c0
@ -6,8 +6,6 @@ Create Date: 2023-08-28 13:21:45.152012
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = '52ab4930be04'
|
revision = '52ab4930be04'
|
||||||
|
@ -64,14 +64,13 @@ def wechat_verify(echostr: str, msg_signature: str,
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/switchs", summary="查询通知消息渠道开关", response_model=List[NotificationSwitch])
|
@router.get("/switchs", summary="查询通知消息渠道开关", response_model=List[NotificationSwitch])
|
||||||
def read_switchs(db: Session = Depends(get_db),
|
def read_switchs(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
|
||||||
"""
|
"""
|
||||||
查询通知消息渠道开关
|
查询通知消息渠道开关
|
||||||
"""
|
"""
|
||||||
return_list = []
|
return_list = []
|
||||||
# 读取数据库
|
# 读取数据库
|
||||||
switchs = SystemConfigOper(db).get(SystemConfigKey.NotificationChannels)
|
switchs = SystemConfigOper().get(SystemConfigKey.NotificationChannels)
|
||||||
if not switchs:
|
if not switchs:
|
||||||
for noti in NotificationType:
|
for noti in NotificationType:
|
||||||
return_list.append(NotificationSwitch(mtype=noti.value, wechat=True, telegram=True, slack=True))
|
return_list.append(NotificationSwitch(mtype=noti.value, wechat=True, telegram=True, slack=True))
|
||||||
@ -83,7 +82,6 @@ def read_switchs(db: Session = Depends(get_db),
|
|||||||
|
|
||||||
@router.post("/switchs", summary="设置通知消息渠道开关", response_model=schemas.Response)
|
@router.post("/switchs", summary="设置通知消息渠道开关", response_model=schemas.Response)
|
||||||
def set_switchs(switchs: List[NotificationSwitch],
|
def set_switchs(switchs: List[NotificationSwitch],
|
||||||
db: Session = Depends(get_db),
|
|
||||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||||
"""
|
"""
|
||||||
查询通知消息渠道开关
|
查询通知消息渠道开关
|
||||||
@ -92,6 +90,6 @@ def set_switchs(switchs: List[NotificationSwitch],
|
|||||||
for switch in switchs:
|
for switch in switchs:
|
||||||
switch_list.append(switch.dict())
|
switch_list.append(switch.dict())
|
||||||
# 存入数据库
|
# 存入数据库
|
||||||
SystemConfigOper(db).set(SystemConfigKey.NotificationChannels, switch_list)
|
SystemConfigOper().set(SystemConfigKey.NotificationChannels, switch_list)
|
||||||
|
|
||||||
return schemas.Response(success=True)
|
return schemas.Response(success=True)
|
||||||
|
@ -22,28 +22,26 @@ def all_plugins(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/installed", summary="已安装插件", response_model=List[str])
|
@router.get("/installed", summary="已安装插件", response_model=List[str])
|
||||||
def installed_plugins(db: Session = Depends(get_db),
|
def installed_plugins(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
|
||||||
"""
|
"""
|
||||||
查询用户已安装插件清单
|
查询用户已安装插件清单
|
||||||
"""
|
"""
|
||||||
return SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or []
|
return SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||||
|
|
||||||
|
|
||||||
@router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response)
|
@router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response)
|
||||||
def install_plugin(plugin_id: str,
|
def install_plugin(plugin_id: str,
|
||||||
db: Session = Depends(get_db),
|
|
||||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||||
"""
|
"""
|
||||||
安装插件
|
安装插件
|
||||||
"""
|
"""
|
||||||
# 已安装插件
|
# 已安装插件
|
||||||
install_plugins = SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or []
|
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||||
# 安装插件
|
# 安装插件
|
||||||
if plugin_id not in install_plugins:
|
if plugin_id not in install_plugins:
|
||||||
install_plugins.append(plugin_id)
|
install_plugins.append(plugin_id)
|
||||||
# 保存设置
|
# 保存设置
|
||||||
SystemConfigOper(db).set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||||
# 重载插件管理器
|
# 重载插件管理器
|
||||||
PluginManager().init_config()
|
PluginManager().init_config()
|
||||||
return schemas.Response(success=True)
|
return schemas.Response(success=True)
|
||||||
@ -93,19 +91,18 @@ def set_plugin_config(plugin_id: str, conf: dict,
|
|||||||
|
|
||||||
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
|
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
|
||||||
def uninstall_plugin(plugin_id: str,
|
def uninstall_plugin(plugin_id: str,
|
||||||
db: Session = Depends(get_db),
|
|
||||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||||
"""
|
"""
|
||||||
卸载插件
|
卸载插件
|
||||||
"""
|
"""
|
||||||
# 删除已安装信息
|
# 删除已安装信息
|
||||||
install_plugins = SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or []
|
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||||
for plugin in install_plugins:
|
for plugin in install_plugins:
|
||||||
if plugin == plugin_id:
|
if plugin == plugin_id:
|
||||||
install_plugins.remove(plugin)
|
install_plugins.remove(plugin)
|
||||||
break
|
break
|
||||||
# 保存
|
# 保存
|
||||||
SystemConfigOper(db).set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||||
# 重载插件管理器
|
# 重载插件管理器
|
||||||
PluginManager().init_config()
|
PluginManager().init_config()
|
||||||
return schemas.Response(success=True)
|
return schemas.Response(success=True)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import List, Any
|
from typing import List, Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app import schemas
|
from app import schemas
|
||||||
|
@ -117,9 +117,9 @@ def cookie_cloud_sync(db: Session = Depends(get_db),
|
|||||||
清空所有站点数据并重新同步CookieCloud站点信息
|
清空所有站点数据并重新同步CookieCloud站点信息
|
||||||
"""
|
"""
|
||||||
Site.reset(db)
|
Site.reset(db)
|
||||||
SystemConfigOper(db).set(SystemConfigKey.IndexerSites, [])
|
SystemConfigOper().set(SystemConfigKey.IndexerSites, [])
|
||||||
SystemConfigOper(db).set(SystemConfigKey.RssSites, [])
|
SystemConfigOper().set(SystemConfigKey.RssSites, [])
|
||||||
CookieCloudChain(db).process(manual=True)
|
CookieCloudChain().process(manual=True)
|
||||||
# 插件站点删除
|
# 插件站点删除
|
||||||
EventManager().send_event(EventType.SiteDeleted,
|
EventManager().send_event(EventType.SiteDeleted,
|
||||||
{
|
{
|
||||||
@ -203,7 +203,7 @@ def site_resource(site_id: int,
|
|||||||
status_code=404,
|
status_code=404,
|
||||||
detail=f"站点 {site_id} 不存在",
|
detail=f"站点 {site_id} 不存在",
|
||||||
)
|
)
|
||||||
torrents = TorrentsChain(db).browse(domain=site.domain)
|
torrents = TorrentsChain().browse(domain=site.domain)
|
||||||
if not torrents:
|
if not torrents:
|
||||||
return []
|
return []
|
||||||
return [torrent.to_dict() for torrent in torrents]
|
return [torrent.to_dict() for torrent in torrents]
|
||||||
@ -234,7 +234,7 @@ def read_rss_sites(db: Session = Depends(get_db)) -> List[dict]:
|
|||||||
获取站点列表
|
获取站点列表
|
||||||
"""
|
"""
|
||||||
# 选中的rss站点
|
# 选中的rss站点
|
||||||
rss_sites = SystemConfigOper(db).get(SystemConfigKey.RssSites)
|
rss_sites = SystemConfigOper().get(SystemConfigKey.RssSites)
|
||||||
# 所有站点
|
# 所有站点
|
||||||
all_site = Site.list_order_by_pri(db)
|
all_site = Site.list_order_by_pri(db)
|
||||||
if not rss_sites or not all_site:
|
if not rss_sites or not all_site:
|
||||||
|
@ -63,24 +63,22 @@ def get_progress(process_type: str, token: str):
|
|||||||
|
|
||||||
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
|
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
|
||||||
def get_setting(key: str,
|
def get_setting(key: str,
|
||||||
db: Session = Depends(get_db),
|
|
||||||
_: schemas.TokenPayload = Depends(verify_token)):
|
_: schemas.TokenPayload = Depends(verify_token)):
|
||||||
"""
|
"""
|
||||||
查询系统设置
|
查询系统设置
|
||||||
"""
|
"""
|
||||||
return schemas.Response(success=True, data={
|
return schemas.Response(success=True, data={
|
||||||
"value": SystemConfigOper(db).get(key)
|
"value": SystemConfigOper().get(key)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
|
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
|
||||||
def set_setting(key: str, value: Union[list, dict, str, int] = None,
|
def set_setting(key: str, value: Union[list, dict, str, int] = None,
|
||||||
db: Session = Depends(get_db),
|
|
||||||
_: schemas.TokenPayload = Depends(verify_token)):
|
_: schemas.TokenPayload = Depends(verify_token)):
|
||||||
"""
|
"""
|
||||||
更新系统设置
|
更新系统设置
|
||||||
"""
|
"""
|
||||||
SystemConfigOper(db).set(key, value)
|
SystemConfigOper().set(key, value)
|
||||||
return schemas.Response(success=True)
|
return schemas.Response(success=True)
|
||||||
|
|
||||||
|
|
||||||
@ -185,9 +183,9 @@ def ruletest(title: str,
|
|||||||
description=subtitle,
|
description=subtitle,
|
||||||
)
|
)
|
||||||
if ruletype == "2":
|
if ruletype == "2":
|
||||||
rule_string = SystemConfigOper(db).get(SystemConfigKey.FilterRules2)
|
rule_string = SystemConfigOper().get(SystemConfigKey.FilterRules2)
|
||||||
else:
|
else:
|
||||||
rule_string = SystemConfigOper(db).get(SystemConfigKey.FilterRules)
|
rule_string = SystemConfigOper().get(SystemConfigKey.FilterRules)
|
||||||
if not rule_string:
|
if not rule_string:
|
||||||
return schemas.Response(success=False, message="过滤规则未设置!")
|
return schemas.Response(success=False, message="过滤规则未设置!")
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ class RssChain(ChainBase):
|
|||||||
super().__init__(db)
|
super().__init__(db)
|
||||||
self.rssoper = RssOper(self._db)
|
self.rssoper = RssOper(self._db)
|
||||||
self.sites = SitesHelper()
|
self.sites = SitesHelper()
|
||||||
self.systemconfig = SystemConfigOper(self._db)
|
self.systemconfig = SystemConfigOper()
|
||||||
self.downloadchain = DownloadChain(self._db)
|
self.downloadchain = DownloadChain(self._db)
|
||||||
self.message = MessageHelper()
|
self.message = MessageHelper()
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ class SearchChain(ChainBase):
|
|||||||
super().__init__(db)
|
super().__init__(db)
|
||||||
self.siteshelper = SitesHelper()
|
self.siteshelper = SitesHelper()
|
||||||
self.progress = ProgressHelper()
|
self.progress = ProgressHelper()
|
||||||
self.systemconfig = SystemConfigOper(self._db)
|
self.systemconfig = SystemConfigOper()
|
||||||
self.torrenthelper = TorrentHelper()
|
self.torrenthelper = TorrentHelper()
|
||||||
|
|
||||||
def search_by_tmdbid(self, tmdbid: int, mtype: MediaType = None, area: str = "title") -> List[Context]:
|
def search_by_tmdbid(self, tmdbid: int, mtype: MediaType = None, area: str = "title") -> List[Context]:
|
||||||
|
@ -33,7 +33,7 @@ class SubscribeChain(ChainBase):
|
|||||||
self.subscribeoper = SubscribeOper(self._db)
|
self.subscribeoper = SubscribeOper(self._db)
|
||||||
self.torrentschain = TorrentsChain()
|
self.torrentschain = TorrentsChain()
|
||||||
self.message = MessageHelper()
|
self.message = MessageHelper()
|
||||||
self.systemconfig = SystemConfigOper(self._db)
|
self.systemconfig = SystemConfigOper()
|
||||||
|
|
||||||
def add(self, title: str, year: str,
|
def add(self, title: str, year: str,
|
||||||
mtype: MediaType = None,
|
mtype: MediaType = None,
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
from typing import Dict, List, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
from cachetools import cached, TTLCache
|
from cachetools import cached, TTLCache
|
||||||
from requests import Session
|
|
||||||
|
|
||||||
from app.chain import ChainBase
|
from app.chain import ChainBase
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.context import TorrentInfo, Context, MediaInfo
|
from app.core.context import TorrentInfo, Context, MediaInfo
|
||||||
from app.core.metainfo import MetaInfo
|
from app.core.metainfo import MetaInfo
|
||||||
|
from app.db import SessionFactory
|
||||||
from app.db.systemconfig_oper import SystemConfigOper
|
from app.db.systemconfig_oper import SystemConfigOper
|
||||||
from app.helper.sites import SitesHelper
|
from app.helper.sites import SitesHelper
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
@ -23,10 +23,11 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
|
|||||||
|
|
||||||
_cache_file = "__torrents_cache__"
|
_cache_file = "__torrents_cache__"
|
||||||
|
|
||||||
def __init__(self, db: Session = None):
|
def __init__(self):
|
||||||
super().__init__(db)
|
self._db = SessionFactory()
|
||||||
|
super().__init__(self._db)
|
||||||
self.siteshelper = SitesHelper()
|
self.siteshelper = SitesHelper()
|
||||||
self.systemconfig = SystemConfigOper(self._db)
|
self.systemconfig = SystemConfigOper()
|
||||||
|
|
||||||
def remote_refresh(self, channel: MessageChannel, userid: Union[str, int] = None):
|
def remote_refresh(self, channel: MessageChannel, userid: Union[str, int] = None):
|
||||||
"""
|
"""
|
||||||
|
@ -13,7 +13,7 @@ from app.chain.transfer import TransferChain
|
|||||||
from app.core.event import Event as ManagerEvent
|
from app.core.event import Event as ManagerEvent
|
||||||
from app.core.event import eventmanager, EventManager
|
from app.core.event import eventmanager, EventManager
|
||||||
from app.core.plugin import PluginManager
|
from app.core.plugin import PluginManager
|
||||||
from app.db import ScopedSession
|
from app.db import SessionFactory
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.schemas.types import EventType, MessageChannel
|
from app.schemas.types import EventType, MessageChannel
|
||||||
from app.utils.object import ObjectUtils
|
from app.utils.object import ObjectUtils
|
||||||
@ -41,7 +41,7 @@ class Command(metaclass=Singleton):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 数据库连接
|
# 数据库连接
|
||||||
self._db = ScopedSession()
|
self._db = SessionFactory()
|
||||||
# 事件管理器
|
# 事件管理器
|
||||||
self.eventmanager = EventManager()
|
self.eventmanager = EventManager()
|
||||||
# 插件管理器
|
# 插件管理器
|
||||||
|
@ -6,7 +6,7 @@ from alembic.config import Config
|
|||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.security import get_password_hash
|
from app.core.security import get_password_hash
|
||||||
from app.db import Engine, ScopedSession
|
from app.db import Engine, SessionFactory
|
||||||
from app.db.models import Base
|
from app.db.models import Base
|
||||||
from app.db.models.user import User
|
from app.db.models.user import User
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
@ -22,7 +22,7 @@ def init_db():
|
|||||||
# 全量建表
|
# 全量建表
|
||||||
Base.metadata.create_all(bind=Engine)
|
Base.metadata.create_all(bind=Engine)
|
||||||
# 初始化超级管理员
|
# 初始化超级管理员
|
||||||
db = ScopedSession()
|
db = SessionFactory()
|
||||||
user = User.get_by_name(db=db, name=settings.SUPERUSER)
|
user = User.get_by_name(db=db, name=settings.SUPERUSER)
|
||||||
if not user:
|
if not user:
|
||||||
user = User(
|
user = User(
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from app.db import DbOper, SessionFactory
|
||||||
|
|
||||||
from app.db import DbOper
|
|
||||||
from app.db.models.systemconfig import SystemConfig
|
from app.db.models.systemconfig import SystemConfig
|
||||||
from app.schemas.types import SystemConfigKey
|
from app.schemas.types import SystemConfigKey
|
||||||
from app.utils.object import ObjectUtils
|
from app.utils.object import ObjectUtils
|
||||||
@ -14,11 +12,12 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
|||||||
# 配置对象
|
# 配置对象
|
||||||
__SYSTEMCONF: dict = {}
|
__SYSTEMCONF: dict = {}
|
||||||
|
|
||||||
def __init__(self, db: Session = None):
|
def __init__(self):
|
||||||
"""
|
"""
|
||||||
加载配置到内存
|
加载配置到内存
|
||||||
"""
|
"""
|
||||||
super().__init__(db)
|
self._db = SessionFactory()
|
||||||
|
super().__init__(self._db)
|
||||||
for item in SystemConfig.list(self._db):
|
for item in SystemConfig.list(self._db):
|
||||||
if ObjectUtils.is_obj(item.value):
|
if ObjectUtils.is_obj(item.value):
|
||||||
self.__SYSTEMCONF[item.key] = json.loads(item.value)
|
self.__SYSTEMCONF[item.key] = json.loads(item.value)
|
||||||
|
@ -592,7 +592,8 @@ class SpeedLimiter(_PluginBase):
|
|||||||
for allow_ipv6 in allow_ipv6s:
|
for allow_ipv6 in allow_ipv6s:
|
||||||
if ipaddr in ipaddress.ip_network(allow_ipv6, strict=False):
|
if ipaddr in ipaddress.ip_network(allow_ipv6, strict=False):
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception as err:
|
||||||
|
print(str(err))
|
||||||
return False
|
return False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from app.chain.rss import RssChain
|
|||||||
from app.chain.subscribe import SubscribeChain
|
from app.chain.subscribe import SubscribeChain
|
||||||
from app.chain.transfer import TransferChain
|
from app.chain.transfer import TransferChain
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.db import ScopedSession
|
from app.db import SessionFactory
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.utils.singleton import Singleton
|
from app.utils.singleton import Singleton
|
||||||
from app.utils.timer import TimerUtils
|
from app.utils.timer import TimerUtils
|
||||||
@ -40,7 +40,7 @@ class Scheduler(metaclass=Singleton):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 数据库连接
|
# 数据库连接
|
||||||
self._db = ScopedSession()
|
self._db = SessionFactory()
|
||||||
# 调试模式不启动定时服务
|
# 调试模式不启动定时服务
|
||||||
if settings.DEV:
|
if settings.DEV:
|
||||||
return
|
return
|
||||||
|
@ -5,8 +5,6 @@ import urllib3
|
|||||||
from requests import Session, Response
|
from requests import Session, Response
|
||||||
from urllib3.exceptions import InsecureRequestWarning
|
from urllib3.exceptions import InsecureRequestWarning
|
||||||
|
|
||||||
from app.utils.ip import IpUtils
|
|
||||||
|
|
||||||
urllib3.disable_warnings(InsecureRequestWarning)
|
urllib3.disable_warnings(InsecureRequestWarning)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user