diff --git a/alembic/versions/52ab4930be04_1_0_3.py b/alembic/versions/52ab4930be04_1_0_3.py index 956b0079..434a8f6f 100644 --- a/alembic/versions/52ab4930be04_1_0_3.py +++ b/alembic/versions/52ab4930be04_1_0_3.py @@ -6,8 +6,6 @@ Create Date: 2023-08-28 13:21:45.152012 """ from alembic import op -import sqlalchemy as sa - # revision identifiers, used by Alembic. revision = '52ab4930be04' diff --git a/app/api/endpoints/message.py b/app/api/endpoints/message.py index 9886ce76..691f0262 100644 --- a/app/api/endpoints/message.py +++ b/app/api/endpoints/message.py @@ -64,14 +64,13 @@ def wechat_verify(echostr: str, msg_signature: str, @router.get("/switchs", summary="查询通知消息渠道开关", response_model=List[NotificationSwitch]) -def read_switchs(db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +def read_switchs(_: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询通知消息渠道开关 """ return_list = [] # 读取数据库 - switchs = SystemConfigOper(db).get(SystemConfigKey.NotificationChannels) + switchs = SystemConfigOper().get(SystemConfigKey.NotificationChannels) if not switchs: for noti in NotificationType: return_list.append(NotificationSwitch(mtype=noti.value, wechat=True, telegram=True, slack=True)) @@ -83,7 +82,6 @@ def read_switchs(db: Session = Depends(get_db), @router.post("/switchs", summary="设置通知消息渠道开关", response_model=schemas.Response) def set_switchs(switchs: List[NotificationSwitch], - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询通知消息渠道开关 @@ -92,6 +90,6 @@ def set_switchs(switchs: List[NotificationSwitch], for switch in switchs: switch_list.append(switch.dict()) # 存入数据库 - SystemConfigOper(db).set(SystemConfigKey.NotificationChannels, switch_list) + SystemConfigOper().set(SystemConfigKey.NotificationChannels, switch_list) return schemas.Response(success=True) diff --git a/app/api/endpoints/plugin.py b/app/api/endpoints/plugin.py index 7a7c4208..47615c96 100644 --- a/app/api/endpoints/plugin.py +++ b/app/api/endpoints/plugin.py @@ -22,28 +22,26 @@ def all_plugins(_: schemas.TokenPayload = Depends(verify_token)) -> Any: @router.get("/installed", summary="已安装插件", response_model=List[str]) -def installed_plugins(db: Session = Depends(get_db), - _: schemas.TokenPayload = Depends(verify_token)) -> Any: +def installed_plugins(_: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询用户已安装插件清单 """ - return SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or [] + return SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or [] @router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response) def install_plugin(plugin_id: str, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 安装插件 """ # 已安装插件 - install_plugins = SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or [] + install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or [] # 安装插件 if plugin_id not in install_plugins: install_plugins.append(plugin_id) # 保存设置 - SystemConfigOper(db).set(SystemConfigKey.UserInstalledPlugins, install_plugins) + SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins) # 重载插件管理器 PluginManager().init_config() return schemas.Response(success=True) @@ -93,19 +91,18 @@ def set_plugin_config(plugin_id: str, conf: dict, @router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response) def uninstall_plugin(plugin_id: str, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 卸载插件 """ # 删除已安装信息 - install_plugins = SystemConfigOper(db).get(SystemConfigKey.UserInstalledPlugins) or [] + install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or [] for plugin in install_plugins: if plugin == plugin_id: install_plugins.remove(plugin) break # 保存 - SystemConfigOper(db).set(SystemConfigKey.UserInstalledPlugins, install_plugins) + SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins) # 重载插件管理器 PluginManager().init_config() return schemas.Response(success=True) diff --git a/app/api/endpoints/search.py b/app/api/endpoints/search.py index 3a2ae68e..406032ea 100644 --- a/app/api/endpoints/search.py +++ b/app/api/endpoints/search.py @@ -1,6 +1,6 @@ from typing import List, Any -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from sqlalchemy.orm import Session from app import schemas diff --git a/app/api/endpoints/site.py b/app/api/endpoints/site.py index d3e1840f..3e078af2 100644 --- a/app/api/endpoints/site.py +++ b/app/api/endpoints/site.py @@ -117,9 +117,9 @@ def cookie_cloud_sync(db: Session = Depends(get_db), 清空所有站点数据并重新同步CookieCloud站点信息 """ Site.reset(db) - SystemConfigOper(db).set(SystemConfigKey.IndexerSites, []) - SystemConfigOper(db).set(SystemConfigKey.RssSites, []) - CookieCloudChain(db).process(manual=True) + SystemConfigOper().set(SystemConfigKey.IndexerSites, []) + SystemConfigOper().set(SystemConfigKey.RssSites, []) + CookieCloudChain().process(manual=True) # 插件站点删除 EventManager().send_event(EventType.SiteDeleted, { @@ -203,7 +203,7 @@ def site_resource(site_id: int, status_code=404, detail=f"站点 {site_id} 不存在", ) - torrents = TorrentsChain(db).browse(domain=site.domain) + torrents = TorrentsChain().browse(domain=site.domain) if not torrents: return [] return [torrent.to_dict() for torrent in torrents] @@ -234,7 +234,7 @@ def read_rss_sites(db: Session = Depends(get_db)) -> List[dict]: 获取站点列表 """ # 选中的rss站点 - rss_sites = SystemConfigOper(db).get(SystemConfigKey.RssSites) + rss_sites = SystemConfigOper().get(SystemConfigKey.RssSites) # 所有站点 all_site = Site.list_order_by_pri(db) if not rss_sites or not all_site: diff --git a/app/api/endpoints/system.py b/app/api/endpoints/system.py index e401c207..df32ee6d 100644 --- a/app/api/endpoints/system.py +++ b/app/api/endpoints/system.py @@ -63,24 +63,22 @@ def get_progress(process_type: str, token: str): @router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response) def get_setting(key: str, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)): """ 查询系统设置 """ return schemas.Response(success=True, data={ - "value": SystemConfigOper(db).get(key) + "value": SystemConfigOper().get(key) }) @router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response) def set_setting(key: str, value: Union[list, dict, str, int] = None, - db: Session = Depends(get_db), _: schemas.TokenPayload = Depends(verify_token)): """ 更新系统设置 """ - SystemConfigOper(db).set(key, value) + SystemConfigOper().set(key, value) return schemas.Response(success=True) @@ -185,9 +183,9 @@ def ruletest(title: str, description=subtitle, ) if ruletype == "2": - rule_string = SystemConfigOper(db).get(SystemConfigKey.FilterRules2) + rule_string = SystemConfigOper().get(SystemConfigKey.FilterRules2) else: - rule_string = SystemConfigOper(db).get(SystemConfigKey.FilterRules) + rule_string = SystemConfigOper().get(SystemConfigKey.FilterRules) if not rule_string: return schemas.Response(success=False, message="过滤规则未设置!") diff --git a/app/chain/rss.py b/app/chain/rss.py index a8daac2a..82f41fd3 100644 --- a/app/chain/rss.py +++ b/app/chain/rss.py @@ -30,7 +30,7 @@ class RssChain(ChainBase): super().__init__(db) self.rssoper = RssOper(self._db) self.sites = SitesHelper() - self.systemconfig = SystemConfigOper(self._db) + self.systemconfig = SystemConfigOper() self.downloadchain = DownloadChain(self._db) self.message = MessageHelper() diff --git a/app/chain/search.py b/app/chain/search.py index b84c3d68..9f49ef71 100644 --- a/app/chain/search.py +++ b/app/chain/search.py @@ -29,7 +29,7 @@ class SearchChain(ChainBase): super().__init__(db) self.siteshelper = SitesHelper() self.progress = ProgressHelper() - self.systemconfig = SystemConfigOper(self._db) + self.systemconfig = SystemConfigOper() self.torrenthelper = TorrentHelper() def search_by_tmdbid(self, tmdbid: int, mtype: MediaType = None, area: str = "title") -> List[Context]: diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index 2c7fe576..03ba0a83 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -33,7 +33,7 @@ class SubscribeChain(ChainBase): self.subscribeoper = SubscribeOper(self._db) self.torrentschain = TorrentsChain() self.message = MessageHelper() - self.systemconfig = SystemConfigOper(self._db) + self.systemconfig = SystemConfigOper() def add(self, title: str, year: str, mtype: MediaType = None, diff --git a/app/chain/torrents.py b/app/chain/torrents.py index cc8eb0d4..d9b561c4 100644 --- a/app/chain/torrents.py +++ b/app/chain/torrents.py @@ -1,12 +1,12 @@ from typing import Dict, List, Union from cachetools import cached, TTLCache -from requests import Session from app.chain import ChainBase from app.core.config import settings from app.core.context import TorrentInfo, Context, MediaInfo from app.core.metainfo import MetaInfo +from app.db import SessionFactory from app.db.systemconfig_oper import SystemConfigOper from app.helper.sites import SitesHelper from app.log import logger @@ -23,10 +23,11 @@ class TorrentsChain(ChainBase, metaclass=Singleton): _cache_file = "__torrents_cache__" - def __init__(self, db: Session = None): - super().__init__(db) + def __init__(self): + self._db = SessionFactory() + super().__init__(self._db) self.siteshelper = SitesHelper() - self.systemconfig = SystemConfigOper(self._db) + self.systemconfig = SystemConfigOper() def remote_refresh(self, channel: MessageChannel, userid: Union[str, int] = None): """ diff --git a/app/command.py b/app/command.py index 0e93d487..a5e3786e 100644 --- a/app/command.py +++ b/app/command.py @@ -13,7 +13,7 @@ from app.chain.transfer import TransferChain from app.core.event import Event as ManagerEvent from app.core.event import eventmanager, EventManager from app.core.plugin import PluginManager -from app.db import ScopedSession +from app.db import SessionFactory from app.log import logger from app.schemas.types import EventType, MessageChannel from app.utils.object import ObjectUtils @@ -41,7 +41,7 @@ class Command(metaclass=Singleton): def __init__(self): # 数据库连接 - self._db = ScopedSession() + self._db = SessionFactory() # 事件管理器 self.eventmanager = EventManager() # 插件管理器 diff --git a/app/db/init.py b/app/db/init.py index 6013bcc8..c902c1fb 100644 --- a/app/db/init.py +++ b/app/db/init.py @@ -6,7 +6,7 @@ from alembic.config import Config from app.core.config import settings from app.core.security import get_password_hash -from app.db import Engine, ScopedSession +from app.db import Engine, SessionFactory from app.db.models import Base from app.db.models.user import User from app.log import logger @@ -22,7 +22,7 @@ def init_db(): # 全量建表 Base.metadata.create_all(bind=Engine) # 初始化超级管理员 - db = ScopedSession() + db = SessionFactory() user = User.get_by_name(db=db, name=settings.SUPERUSER) if not user: user = User( diff --git a/app/db/systemconfig_oper.py b/app/db/systemconfig_oper.py index 204b3e02..64fcca0f 100644 --- a/app/db/systemconfig_oper.py +++ b/app/db/systemconfig_oper.py @@ -1,9 +1,7 @@ import json from typing import Any, Union -from sqlalchemy.orm import Session - -from app.db import DbOper +from app.db import DbOper, SessionFactory from app.db.models.systemconfig import SystemConfig from app.schemas.types import SystemConfigKey from app.utils.object import ObjectUtils @@ -14,11 +12,12 @@ class SystemConfigOper(DbOper, metaclass=Singleton): # 配置对象 __SYSTEMCONF: dict = {} - def __init__(self, db: Session = None): + def __init__(self): """ 加载配置到内存 """ - super().__init__(db) + self._db = SessionFactory() + super().__init__(self._db) for item in SystemConfig.list(self._db): if ObjectUtils.is_obj(item.value): self.__SYSTEMCONF[item.key] = json.loads(item.value) diff --git a/app/plugins/speedlimiter/__init__.py b/app/plugins/speedlimiter/__init__.py index 2f187adb..4ba885ff 100644 --- a/app/plugins/speedlimiter/__init__.py +++ b/app/plugins/speedlimiter/__init__.py @@ -592,7 +592,8 @@ class SpeedLimiter(_PluginBase): for allow_ipv6 in allow_ipv6s: if ipaddr in ipaddress.ip_network(allow_ipv6, strict=False): return True - except Exception: + except Exception as err: + print(str(err)) return False return False diff --git a/app/scheduler.py b/app/scheduler.py index a0587669..3962c0bc 100644 --- a/app/scheduler.py +++ b/app/scheduler.py @@ -12,7 +12,7 @@ from app.chain.rss import RssChain from app.chain.subscribe import SubscribeChain from app.chain.transfer import TransferChain from app.core.config import settings -from app.db import ScopedSession +from app.db import SessionFactory from app.log import logger from app.utils.singleton import Singleton from app.utils.timer import TimerUtils @@ -40,7 +40,7 @@ class Scheduler(metaclass=Singleton): def __init__(self): # 数据库连接 - self._db = ScopedSession() + self._db = SessionFactory() # 调试模式不启动定时服务 if settings.DEV: return diff --git a/app/utils/http.py b/app/utils/http.py index 78c96d53..7f032a17 100644 --- a/app/utils/http.py +++ b/app/utils/http.py @@ -5,8 +5,6 @@ import urllib3 from requests import Session, Response from urllib3.exceptions import InsecureRequestWarning -from app.utils.ip import IpUtils - urllib3.disable_warnings(InsecureRequestWarning)