fix dboper

This commit is contained in:
jxxghp 2023-06-15 07:12:59 +08:00
parent 7506f39258
commit ab4895ff85
13 changed files with 106 additions and 89 deletions

View File

@ -6,8 +6,8 @@ from lxml import etree
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.db.siteicons import SiteIcons from app.db.siteicon_oper import SiteIconOper
from app.db.sites import Sites from app.db.site_oper import SiteOper
from app.helper.cookiecloud import CookieCloudHelper from app.helper.cookiecloud import CookieCloudHelper
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper
from app.log import logger from app.log import logger
@ -21,8 +21,8 @@ class CookieCloudChain(ChainBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.sites = Sites() self.siteoper = SiteOper()
self.siteicons = SiteIcons() self.siteiconoper = SiteIconOper()
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
self.cookiecloud = CookieCloudHelper( self.cookiecloud = CookieCloudHelper(
server=settings.COOKIECLOUD_HOST, server=settings.COOKIECLOUD_HOST,
@ -45,13 +45,13 @@ class CookieCloudChain(ChainBase):
for domain, cookie in cookies.items(): for domain, cookie in cookies.items():
# 获取站点信息 # 获取站点信息
indexer = self.siteshelper.get_indexer(domain) indexer = self.siteshelper.get_indexer(domain)
if self.sites.exists(domain): if self.siteoper.exists(domain):
# 更新站点Cookie # 更新站点Cookie
self.sites.update_cookie(domain=domain, cookies=cookie) self.siteoper.update_cookie(domain=domain, cookies=cookie)
_update_count += 1 _update_count += 1
elif indexer: elif indexer:
# 新增站点 # 新增站点
self.sites.add(name=indexer.get("name"), self.siteoper.add(name=indexer.get("name"),
url=indexer.get("domain"), url=indexer.get("domain"),
domain=domain, domain=domain,
cookie=cookie) cookie=cookie)
@ -62,7 +62,7 @@ class CookieCloudChain(ChainBase):
cookie=cookie, cookie=cookie,
ua=settings.USER_AGENT) ua=settings.USER_AGENT)
if icon_url: if icon_url:
self.siteicons.update_icon(name=indexer.get("name"), self.siteiconoper.update_icon(name=indexer.get("name"),
domain=domain, domain=domain,
icon_url=icon_url, icon_url=icon_url,
icon_base64=icon_base64) icon_base64=icon_base64)

View File

@ -1,5 +1,5 @@
from app.chain import ChainBase from app.chain import ChainBase
from app.db.sites import Sites from app.db.site_oper import SiteOper
class SiteManageChain(ChainBase): class SiteManageChain(ChainBase):
@ -7,17 +7,17 @@ class SiteManageChain(ChainBase):
站点远程管理处理链 站点远程管理处理链
""" """
_sites: Sites = None _sites: SiteOper = None
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._sites = Sites() self._siteoper = SiteOper()
def process(self): def process(self):
""" """
查询所有站点发送消息 查询所有站点发送消息
""" """
site_list = self._sites.list() site_list = self._siteoper.list()
if not site_list: if not site_list:
self.post_message(title="没有维护任何站点信息!") self.post_message(title="没有维护任何站点信息!")
title = f"共有 {len(site_list)} 个站点,回复 `/site_disable` `[id]` 禁用站点,回复 `/site_enable` `[id]` 启用站点:" title = f"共有 {len(site_list)} 个站点,回复 `/site_disable` `[id]` 禁用站点,回复 `/site_enable` `[id]` 启用站点:"
@ -44,12 +44,12 @@ class SiteManageChain(ChainBase):
if not arg_str.isdigit(): if not arg_str.isdigit():
return return
site_id = int(arg_str) site_id = int(arg_str)
site = self._sites.get(site_id) site = self._siteoper.get(site_id)
if not site: if not site:
self.post_message(title=f"站点编号 {site_id} 不存在!") self.post_message(title=f"站点编号 {site_id} 不存在!")
return return
# 禁用站点 # 禁用站点
self._sites.update(site_id, { self._siteoper.update(site_id, {
"is_active": False "is_active": False
}) })
# 重新发送消息 # 重新发送消息
@ -65,12 +65,12 @@ class SiteManageChain(ChainBase):
if not arg_str.isdigit(): if not arg_str.isdigit():
return return
site_id = int(arg_str) site_id = int(arg_str)
site = self._sites.get(site_id) site = self._siteoper.get(site_id)
if not site: if not site:
self.post_message(title=f"站点编号 {site_id} 不存在!") self.post_message(title=f"站点编号 {site_id} 不存在!")
return return
# 禁用站点 # 禁用站点
self._sites.update(site_id, { self._siteoper.update(site_id, {
"is_active": True "is_active": True
}) })
# 重新发送消息 # 重新发送消息

View File

@ -6,7 +6,7 @@ from app.chain.search import SearchChain
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.core.context import TorrentInfo, Context, MediaInfo from app.core.context import TorrentInfo, Context, MediaInfo
from app.core.config import settings from app.core.config import settings
from app.db.subscribes import Subscribes from app.db.subscribe_oper import SubscribeOper
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper
from app.log import logger from app.log import logger
from app.schemas.context import NotExistMediaInfo from app.schemas.context import NotExistMediaInfo
@ -26,7 +26,7 @@ class SubscribeChain(ChainBase):
super().__init__() super().__init__()
self.downloadchain = DownloadChain() self.downloadchain = DownloadChain()
self.searchchain = SearchChain() self.searchchain = SearchChain()
self.subscribes = Subscribes() self.subscribehelper = SubscribeOper()
self.siteshelper = SitesHelper() self.siteshelper = SitesHelper()
def process(self, title: str, year: str, def process(self, title: str, year: str,
@ -89,7 +89,7 @@ class SubscribeChain(ChainBase):
'lack_episode': kwargs.get('total_episode') 'lack_episode': kwargs.get('total_episode')
}) })
# 添加订阅 # 添加订阅
sid, err_msg = self.subscribes.add(mediainfo, season=season, **kwargs) sid, err_msg = self.subscribehelper.add(mediainfo, season=season, **kwargs)
if not sid: if not sid:
logger.error(f'{mediainfo.title_year} {err_msg}') logger.error(f'{mediainfo.title_year} {err_msg}')
# 发回原用户 # 发回原用户
@ -115,15 +115,15 @@ class SubscribeChain(ChainBase):
:return: 更新订阅状态为R或删除订阅 :return: 更新订阅状态为R或删除订阅
""" """
if sid: if sid:
subscribes = [self.subscribes.get(sid)] subscribes = [self.subscribehelper.get(sid)]
else: else:
subscribes = self.subscribes.list(state) subscribes = self.subscribehelper.list(state)
# 遍历订阅 # 遍历订阅
for subscribe in subscribes: for subscribe in subscribes:
logger.info(f'开始搜索订阅,标题:{subscribe.name} ...') logger.info(f'开始搜索订阅,标题:{subscribe.name} ...')
# 如果状态为N则更新为R # 如果状态为N则更新为R
if subscribe.state == 'N': if subscribe.state == 'N':
self.subscribes.update(subscribe.id, {'state': 'R'}) self.subscribehelper.update(subscribe.id, {'state': 'R'})
# 生成元数据 # 生成元数据
meta = MetaInfo(subscribe.name) meta = MetaInfo(subscribe.name)
meta.year = subscribe.year meta.year = subscribe.year
@ -138,7 +138,7 @@ class SubscribeChain(ChainBase):
exist_flag, no_exists = self.downloadchain.get_no_exists_info(meta=meta, mediainfo=mediainfo) exist_flag, no_exists = self.downloadchain.get_no_exists_info(meta=meta, mediainfo=mediainfo)
if exist_flag: if exist_flag:
logger.info(f'{mediainfo.title_year} 媒体库中已存在,完成订阅') logger.info(f'{mediainfo.title_year} 媒体库中已存在,完成订阅')
self.subscribes.delete(subscribe.id) self.subscribehelper.delete(subscribe.id)
# 发送通知 # 发送通知
self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅', self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅',
image=mediainfo.get_message_image()) image=mediainfo.get_message_image())
@ -165,7 +165,7 @@ class SubscribeChain(ChainBase):
if downloads and not lefts: if downloads and not lefts:
# 全部下载完成 # 全部下载完成
logger.info(f'{mediainfo.title_year} 下载完成,完成订阅') logger.info(f'{mediainfo.title_year} 下载完成,完成订阅')
self.subscribes.delete(subscribe.id) self.subscribehelper.delete(subscribe.id)
# 发送通知 # 发送通知
self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅', self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅',
image=mediainfo.get_message_image()) image=mediainfo.get_message_image())
@ -224,7 +224,7 @@ class SubscribeChain(ChainBase):
从缓存中匹配订阅并自动下载 从缓存中匹配订阅并自动下载
""" """
# 所有订阅 # 所有订阅
subscribes = self.subscribes.list('R') subscribes = self.subscribehelper.list('R')
# 遍历订阅 # 遍历订阅
for subscribe in subscribes: for subscribe in subscribes:
logger.info(f'开始匹配订阅,标题:{subscribe.name} ...') logger.info(f'开始匹配订阅,标题:{subscribe.name} ...')
@ -242,7 +242,7 @@ class SubscribeChain(ChainBase):
exist_flag, no_exists = self.downloadchain.get_no_exists_info(meta=meta, mediainfo=mediainfo) exist_flag, no_exists = self.downloadchain.get_no_exists_info(meta=meta, mediainfo=mediainfo)
if exist_flag: if exist_flag:
logger.info(f'{mediainfo.title_year} 媒体库中已存在,完成订阅') logger.info(f'{mediainfo.title_year} 媒体库中已存在,完成订阅')
self.subscribes.delete(subscribe.id) self.subscribehelper.delete(subscribe.id)
# 发送通知 # 发送通知
self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅', self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅',
image=mediainfo.get_message_image()) image=mediainfo.get_message_image())
@ -278,7 +278,7 @@ class SubscribeChain(ChainBase):
if downloads and not lefts: if downloads and not lefts:
# 全部下载完成 # 全部下载完成
logger.info(f'{mediainfo.title_year} 下载完成,完成订阅') logger.info(f'{mediainfo.title_year} 下载完成,完成订阅')
self.subscribes.delete(subscribe.id) self.subscribehelper.delete(subscribe.id)
# 发送通知 # 发送通知
self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅', self.post_message(title=f'{mediainfo.title_year}{meta.season} 已完成订阅',
image=mediainfo.get_message_image()) image=mediainfo.get_message_image())
@ -291,7 +291,7 @@ class SubscribeChain(ChainBase):
left_episodes = season_info.get('episodes') left_episodes = season_info.get('episodes')
logger.info(f'{mediainfo.title_year}{season} 未下载完整,' logger.info(f'{mediainfo.title_year}{season} 未下载完整,'
f'更新缺失集数为{len(left_episodes)} ...') f'更新缺失集数为{len(left_episodes)} ...')
self.subscribes.update(subscribe.id, { self.subscribehelper.update(subscribe.id, {
"lack_episode": len(left_episodes) "lack_episode": len(left_episodes)
}) })
@ -299,7 +299,7 @@ class SubscribeChain(ChainBase):
""" """
查询订阅并发送消息 查询订阅并发送消息
""" """
subscribes = self.subscribes.list() subscribes = self.subscribehelper.list()
if not subscribes: if not subscribes:
self.post_message(title='没有任何订阅!') self.post_message(title='没有任何订阅!')
return return
@ -328,12 +328,12 @@ class SubscribeChain(ChainBase):
if not arg_str.isdigit(): if not arg_str.isdigit():
return return
subscribe_id = int(arg_str) subscribe_id = int(arg_str)
subscribe = self.subscribes.get(subscribe_id) subscribe = self.subscribehelper.get(subscribe_id)
if not subscribe: if not subscribe:
self.post_message(title=f"订阅编号 {subscribe_id} 不存在!") self.post_message(title=f"订阅编号 {subscribe_id} 不存在!")
return return
# 删除订阅 # 删除订阅
self.subscribes.delete(subscribe_id) self.subscribehelper.delete(subscribe_id)
# 重新发送消息 # 重新发送消息
self.list() self.list()

View File

@ -1,7 +1,7 @@
import traceback import traceback
from typing import List, Any from typing import List, Any
from app.db.systemconfigs import SystemConfigs from app.db.systemconfig_oper import SystemConfigOper
from app.helper.module import ModuleHelper from app.helper.module import ModuleHelper
from app.log import logger from app.log import logger
from app.utils.singleton import Singleton from app.utils.singleton import Singleton
@ -11,7 +11,7 @@ class PluginManager(metaclass=Singleton):
""" """
插件管理器 插件管理器
""" """
systemconfigs: SystemConfigs = None systemconfigs: SystemConfigOper = None
# 插件列表 # 插件列表
_plugins: dict = {} _plugins: dict = {}
@ -24,7 +24,7 @@ class PluginManager(metaclass=Singleton):
self.init_config() self.init_config()
def init_config(self): def init_config(self):
self.systemconfigs = SystemConfigs() self.systemconfigs = SystemConfigOper()
# 停止已有插件 # 停止已有插件
self.stop() self.stop()
# 启动插件 # 启动插件

View File

@ -1,5 +1,5 @@
from sqlalchemy import create_engine, QueuePool from sqlalchemy import create_engine, QueuePool
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker, Session
from app.core.config import settings from app.core.config import settings
@ -27,3 +27,15 @@ def get_db():
finally: finally:
if db: if db:
db.close() db.close()
class DbOper:
_db: Session = None
def __init__(self, _db=SessionLocal()):
self._db = _db
def __del__(self):
if self._db:
self._db.close()

View File

@ -11,7 +11,6 @@ class Base:
def create(self, db): def create(self, db):
db.add(self) db.add(self)
db.commit() db.commit()
db.refresh(self)
return self return self
@classmethod @classmethod

35
app/db/plugindata_oper.py Normal file
View File

@ -0,0 +1,35 @@
import json
from typing import Any
from app.db import DbOper
from app.db.models import Base
from app.db.models.plugin import PluginData
from app.utils.object import ObjectUtils
class PluginDataOper(DbOper):
"""
插件数据管理
"""
def save(self, plugin_id: str, key: str, value: Any) -> Base:
"""
保存插件数据
:param plugin_id: 插件id
:param key: 数据key
:param value: 数据值
"""
if ObjectUtils.is_obj(value):
value = json.dumps(value)
plugin = PluginData(plugin_id=plugin_id, key=key, value=value)
return plugin.create(self._db)
def get_data(self, key: str) -> Any:
"""
获取插件数据
:param key: 数据key
"""
data = PluginData.get_plugin_data_by_key(self._db, self.__class__.__name__, key)
if ObjectUtils.is_obj(data):
return json.load(data)
return data

View File

@ -1,19 +1,13 @@
from typing import Tuple, List from typing import Tuple, List
from sqlalchemy.orm import Session from app.db import DbOper
from app.db import SessionLocal
from app.db.models.site import Site from app.db.models.site import Site
class Sites: class SiteOper(DbOper):
""" """
站点管理 站点管理
""" """
_db: Session = None
def __init__(self, _db=SessionLocal()):
self._db = _db
def add(self, **kwargs) -> Tuple[bool, str]: def add(self, **kwargs) -> Tuple[bool, str]:
""" """

View File

@ -1,19 +1,13 @@
from typing import List from typing import List
from sqlalchemy.orm import Session from app.db import DbOper
from app.db import SessionLocal
from app.db.models.siteicon import SiteIcon from app.db.models.siteicon import SiteIcon
class SiteIcons: class SiteIconOper(DbOper):
""" """
站点管理 站点管理
""" """
_db: Session = None
def __init__(self, _db=SessionLocal()):
self._db = _db
def list(self) -> List[SiteIcon]: def list(self) -> List[SiteIcon]:
""" """

View File

@ -1,20 +1,14 @@
from typing import Tuple, List from typing import Tuple, List
from sqlalchemy.orm import Session
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.db import SessionLocal from app.db import DbOper
from app.db.models.subscribe import Subscribe from app.db.models.subscribe import Subscribe
class Subscribes: class SubscribeOper(DbOper):
""" """
订阅管理 订阅管理
""" """
_db: Session = None
def __init__(self, _db=SessionLocal()):
self._db = _db
def add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]: def add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]:
""" """

View File

@ -1,25 +1,22 @@
import json import json
from typing import Any, Union from typing import Any, Union
from sqlalchemy.orm import Session from app.db import DbOper
from app.db import SessionLocal
from app.db.models.systemconfig import SystemConfig from app.db.models.systemconfig import SystemConfig
from app.utils.object import ObjectUtils from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton from app.utils.singleton import Singleton
from app.utils.types import SystemConfigKey from app.utils.types import SystemConfigKey
class SystemConfigs(metaclass=Singleton): class SystemConfigOper(DbOper, metaclass=Singleton):
# 配置对象 # 配置对象
__SYSTEMCONF: dict = {} __SYSTEMCONF: dict = {}
_db: Session = None
def __init__(self, _db=SessionLocal()): def __init__(self):
""" """
加载配置到内存 加载配置到内存
""" """
self._db = _db super().__init__()
for item in SystemConfig.list(self._db): for item in SystemConfig.list(self._db):
if ObjectUtils.is_obj(item.value): if ObjectUtils.is_obj(item.value):
self.__SYSTEMCONF[item.key] = json.loads(item.value) self.__SYSTEMCONF[item.key] = json.loads(item.value)

Binary file not shown.

View File

@ -1,15 +1,12 @@
import json
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.db import SessionLocal
from app.db.models import Base from app.db.models import Base
from app.db.models.plugin import PluginData from app.db.plugindata_oper import PluginDataOper
from app.db.systemconfigs import SystemConfigs from app.db.systemconfig_oper import SystemConfigOper
from app.utils.object import ObjectUtils
class PluginChian(ChainBase): class PluginChian(ChainBase):
@ -39,8 +36,9 @@ class _PluginBase(metaclass=ABCMeta):
plugin_desc: str = "" plugin_desc: str = ""
def __init__(self): def __init__(self):
self.db = SessionLocal() self.plugindata = PluginDataOper()
self.chain = PluginChian() self.chain = PluginChian()
self.systemconfig = SystemConfigOper()
@abstractmethod @abstractmethod
def init_plugin(self, config: dict = None): def init_plugin(self, config: dict = None):
@ -65,7 +63,7 @@ class _PluginBase(metaclass=ABCMeta):
""" """
if not plugin_id: if not plugin_id:
plugin_id = self.__class__.__name__ plugin_id = self.__class__.__name__
return SystemConfigs().set(f"plugin.{plugin_id}", config) return self.systemconfig.set(f"plugin.{plugin_id}", config)
def get_config(self, plugin_id: str = None) -> Any: def get_config(self, plugin_id: str = None) -> Any:
""" """
@ -74,7 +72,7 @@ class _PluginBase(metaclass=ABCMeta):
""" """
if not plugin_id: if not plugin_id:
plugin_id = self.__class__.__name__ plugin_id = self.__class__.__name__
return SystemConfigs().get(f"plugin.{plugin_id}") return self.systemconfig.get(f"plugin.{plugin_id}")
def get_data_path(self, plugin_id: str = None) -> Path: def get_data_path(self, plugin_id: str = None) -> Path:
""" """
@ -93,17 +91,11 @@ class _PluginBase(metaclass=ABCMeta):
:param key: 数据key :param key: 数据key
:param value: 数据值 :param value: 数据值
""" """
if ObjectUtils.is_obj(value): return self.plugindata.save(self.__class__.__name__, key, value)
value = json.dumps(value)
plugin = PluginData(plugin_id=self.__class__.__name__, key=key, value=value)
return plugin.create(self.db)
def get_data(self, key: str) -> Any: def get_data(self, key: str) -> Any:
""" """
获取插件数据 获取插件数据
:param key: 数据key :param key: 数据key
""" """
data = PluginData.get_plugin_data_by_key(self.db, self.__class__.__name__, key) return self.plugindata.get_data(key)
if ObjectUtils.is_obj(data):
return json.load(data)
return data