diff --git a/app/chain/__init__.py b/app/chain/__init__.py index adf1e986..c087ea0e 100644 --- a/app/chain/__init__.py +++ b/app/chain/__init__.py @@ -1,9 +1,11 @@ +import pickle import traceback from pathlib import Path from typing import Optional, Any, Tuple, List, Set, Union, Dict from ruamel.yaml import CommentedMap +from app.core.config import settings from app.core.context import Context from app.core.context import MediaInfo, TorrentInfo from app.core.event import EventManager @@ -28,6 +30,23 @@ class ChainBase(AbstractSingleton, metaclass=Singleton): self.modulemanager = ModuleManager() self.eventmanager = EventManager() + @staticmethod + def __load_cache(filename: str) -> Any: + """ + 从本地加载缓存 + """ + cache_path = settings.TEMP_PATH / filename + if cache_path.exists(): + return pickle.load(cache_path.open('rb')) + return None + + @staticmethod + def __save_cache(cache: Any, filename: str) -> None: + """ + 保存缓存到本地 + """ + pickle.dump(cache, (settings.TEMP_PATH / filename).open('wb')) + def run_module(self, method: str, *args, **kwargs) -> Any: """ 运行包含该方法的所有模块,然后返回结果 diff --git a/app/chain/message.py b/app/chain/message.py index 25309fb9..02f74a68 100644 --- a/app/chain/message.py +++ b/app/chain/message.py @@ -16,7 +16,7 @@ class MessageChain(ChainBase): 外来消息处理链 """ # 缓存的用户数据 {userid: {type: str, items: list}} - _user_cache: Dict[str, dict] = {} + _cache_file = "__user_messages__" # 每页数据量 _page_size: int = 8 # 当前页面 @@ -58,6 +58,9 @@ class MessageChain(ChainBase): if not text: logger.debug(f'未识别到消息内容::{body}{form}{args}') return + # 加载缓存 + user_cache: Dict[str, dict] = self.__load_cache(self._cache_file) or {} + # 处理消息 logger.info(f'收到用户消息内容,用户:{userid},内容:{text}') if text.startswith('/'): # 执行命令 @@ -72,7 +75,7 @@ class MessageChain(ChainBase): elif text.isdigit(): # 缓存 - cache_data: dict = self._user_cache.get(userid) + cache_data: dict = user_cache.get(userid) # 选择项目 if not cache_data \ or not cache_data.get('items') \ @@ -125,7 +128,7 @@ class MessageChain(ChainBase): # 搜索结果排序 contexts = self.torrenthelper.sort_torrents(contexts) # 更新缓存 - self._user_cache[userid] = { + user_cache[userid] = { "type": "Torrent", "items": contexts } @@ -201,7 +204,7 @@ class MessageChain(ChainBase): elif text.lower() == "p": # 上一页 - cache_data: dict = self._user_cache.get(userid) + cache_data: dict = user_cache.get(userid) if not cache_data: # 没有缓存 self.post_message(Notification( @@ -240,7 +243,7 @@ class MessageChain(ChainBase): elif text.lower() == "n": # 下一页 - cache_data: dict = self._user_cache.get(userid) + cache_data: dict = user_cache.get(userid) if not cache_data: # 没有缓存 self.post_message(Notification( @@ -296,7 +299,7 @@ class MessageChain(ChainBase): logger.info(f"搜索到 {len(medias)} 条相关媒体信息") # 记录当前状态 self._current_meta = meta - self._user_cache[userid] = { + user_cache[userid] = { 'type': action, 'items': medias } @@ -307,6 +310,8 @@ class MessageChain(ChainBase): title=meta.name, items=medias[:self._page_size], userid=userid, total=len(medias)) + # 保存缓存 + self.__save_cache(user_cache, self._cache_file) def __post_medias_message(self, channel: MessageChannel, title: str, items: list, userid: str, total: int): diff --git a/app/chain/subscribe.py b/app/chain/subscribe.py index f41bf308..e12281f6 100644 --- a/app/chain/subscribe.py +++ b/app/chain/subscribe.py @@ -1,8 +1,6 @@ import json -import pickle import re from datetime import datetime -from pathlib import Path from typing import Dict, List, Optional, Union, Tuple from app.chain import ChainBase @@ -27,7 +25,7 @@ class SubscribeChain(ChainBase): 订阅管理处理链 """ - __cache_path: Path = None + _cache_file = "__torrents_cache__" def __init__(self): super().__init__() @@ -38,9 +36,6 @@ class SubscribeChain(ChainBase): self.message = MessageHelper() self.systemconfig = SystemConfigOper() - # 缓存路径 - self.__cache_path = settings.TEMP_PATH / "__torrents_cache__" - def add(self, title: str, year: str, mtype: MediaType = None, tmdbid: int = None, @@ -302,7 +297,7 @@ class SubscribeChain(ChainBase): 刷新站点最新资源 """ # 读取缓存 - torrents_cache: Dict[str, List[Context]] = self.__load_cache() + torrents_cache: Dict[str, List[Context]] = self.__load_cache(self._cache_file) or {} # 所有站点索引 indexers = self.siteshelper.get_indexers() @@ -368,21 +363,7 @@ class SubscribeChain(ChainBase): # 从缓存中匹配订阅 self.__match(torrents_cache) # 保存缓存到本地 - self.__save_cache(torrents_cache) - - def __load_cache(self) -> Dict[str, List[Context]]: - """ - 从本地加载缓存 - """ - if self.__cache_path.exists(): - return pickle.load(self.__cache_path.open('rb')) or {} - return {} - - def __save_cache(self, cache: Dict[str, List[Context]]): - """ - 保存缓存到本地 - """ - pickle.dump(cache, self.__cache_path.open('wb')) + self.__save_cache(torrents_cache, self._cache_file) def __match(self, torrents_cache: Dict[str, List[Context]]): """