diff --git a/app/db/__init__.py b/app/db/__init__.py index 19649177..909b8d2b 100644 --- a/app/db/__init__.py +++ b/app/db/__init__.py @@ -1,7 +1,10 @@ +from typing import Any, Self, List from typing import Tuple, Optional, Generator from sqlalchemy import create_engine, QueuePool -from sqlalchemy.orm import sessionmaker, Session, scoped_session +from sqlalchemy import inspect +from sqlalchemy.orm import declared_attr +from sqlalchemy.orm import sessionmaker, Session, scoped_session, as_declarative from app.core.config import settings @@ -135,6 +138,52 @@ def db_query(func): return wrapper +@as_declarative() +class Base: + id: Any + __name__: str + + @db_update + def create(self, db: Session): + db.add(self) + + @classmethod + @db_query + def get(cls, db: Session, rid: int) -> Self: + return db.query(cls).filter(cls.id == rid).first() + + @db_update + def update(self, db: Session, payload: dict): + payload = {k: v for k, v in payload.items() if v is not None} + for key, value in payload.items(): + setattr(self, key, value) + if inspect(self).detached: + db.add(self) + + @classmethod + @db_update + def delete(cls, db: Session, rid): + db.query(cls).filter(cls.id == rid).delete() + + @classmethod + @db_update + def truncate(cls, db: Session): + db.query(cls).delete() + + @classmethod + @db_query + def list(cls, db: Session) -> List[Self]: + result = db.query(cls).all() + return list(result) + + def to_dict(self): + return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} + + @declared_attr + def __tablename__(self) -> str: + return self.__name__.lower() + + class DbOper: """ 数据库操作基类 diff --git a/app/db/init.py b/app/db/init.py index d21b1090..bd3444f4 100644 --- a/app/db/init.py +++ b/app/db/init.py @@ -1,15 +1,10 @@ -import importlib -from pathlib import Path - from alembic.command import upgrade from alembic.config import Config from app.core.config import settings from app.core.security import get_password_hash -from app.db import Engine, SessionFactory -from app.db.models import Base -from app.db.models.user import User -from app.helper.module import ModuleHelper +from app.db import Engine, SessionFactory, Base +from app.db.models import * from app.log import logger @@ -17,21 +12,18 @@ def init_db(): """ 初始化数据库 """ - # 导入模块,避免建表缺失 - models_path = Path(__file__).with_name("models") - ModuleHelper.dynamic_import_all_modules(models_path, "app.db.models") # 全量建表 Base.metadata.create_all(bind=Engine) # 初始化超级管理员 with SessionFactory() as db: - user = User.get_by_name(db=db, name=settings.SUPERUSER) - if not user: - user = User( + _user = User.get_by_name(db=db, name=settings.SUPERUSER) + if not _user: + _user = User( name=settings.SUPERUSER, hashed_password=get_password_hash(settings.SUPERUSER_PASSWORD), is_superuser=True, ) - user.create(db) + _user.create(db) def update_db(): diff --git a/app/db/models/__init__.py b/app/db/models/__init__.py index 067d2559..57ddee03 100644 --- a/app/db/models/__init__.py +++ b/app/db/models/__init__.py @@ -1,52 +1,9 @@ -from typing import Any, Self, List - -from sqlalchemy import inspect -from sqlalchemy.orm import as_declarative, declared_attr, Session - -from app.db import db_update, db_query - - -@as_declarative() -class Base: - id: Any - __name__: str - - @db_update - def create(self, db: Session): - db.add(self) - - @classmethod - @db_query - def get(cls, db: Session, rid: int) -> Self: - return db.query(cls).filter(cls.id == rid).first() - - @db_update - def update(self, db: Session, payload: dict): - payload = {k: v for k, v in payload.items() if v is not None} - for key, value in payload.items(): - setattr(self, key, value) - if inspect(self).detached: - db.add(self) - - @classmethod - @db_update - def delete(cls, db: Session, rid): - db.query(cls).filter(cls.id == rid).delete() - - @classmethod - @db_update - def truncate(cls, db: Session): - db.query(cls).delete() - - @classmethod - @db_query - def list(cls, db: Session) -> List[Self]: - result = db.query(cls).all() - return list(result) - - def to_dict(self): - return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} - - @declared_attr - def __tablename__(self) -> str: - return self.__name__.lower() +from .downloadhistory import DownloadHistory, DownloadFiles +from .mediaserver import MediaServerItem +from .plugindata import PluginData +from .site import Site +from .siteicon import SiteIcon +from .subscribe import Subscribe +from .systemconfig import SystemConfig +from .transferhistory import TransferHistory +from .user import User diff --git a/app/db/models/downloadhistory.py b/app/db/models/downloadhistory.py index d37ab652..1e796877 100644 --- a/app/db/models/downloadhistory.py +++ b/app/db/models/downloadhistory.py @@ -1,8 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db import db_query -from app.db.models import Base, db_update +from app.db import db_query, db_update, Base class DownloadHistory(Base): diff --git a/app/db/models/mediaserver.py b/app/db/models/mediaserver.py index df28d072..1690bc04 100644 --- a/app/db/models/mediaserver.py +++ b/app/db/models/mediaserver.py @@ -4,8 +4,7 @@ from typing import Optional from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db import db_query -from app.db.models import Base, db_update +from app.db import db_query, db_update, Base class MediaServerItem(Base): diff --git a/app/db/models/plugindata.py b/app/db/models/plugindata.py index f6346728..d10ee14c 100644 --- a/app/db/models/plugindata.py +++ b/app/db/models/plugindata.py @@ -1,8 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db import db_query -from app.db.models import Base, db_update +from app.db import db_query, db_update, Base class PluginData(Base): diff --git a/app/db/models/site.py b/app/db/models/site.py index e79fe413..297e8813 100644 --- a/app/db/models/site.py +++ b/app/db/models/site.py @@ -3,8 +3,7 @@ from datetime import datetime from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db import db_query -from app.db.models import Base, db_update +from app.db import db_query, db_update, Base class Site(Base): diff --git a/app/db/models/siteicon.py b/app/db/models/siteicon.py index 787b86c4..770cd37e 100644 --- a/app/db/models/siteicon.py +++ b/app/db/models/siteicon.py @@ -1,8 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db import db_query -from app.db.models import Base +from app.db import db_query, Base class SiteIcon(Base): diff --git a/app/db/models/subscribe.py b/app/db/models/subscribe.py index b9699e36..187c2c38 100644 --- a/app/db/models/subscribe.py +++ b/app/db/models/subscribe.py @@ -1,8 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db import db_update, db_query -from app.db.models import Base +from app.db import db_query, db_update, Base class Subscribe(Base): diff --git a/app/db/models/systemconfig.py b/app/db/models/systemconfig.py index 5a304c01..d26b9aee 100644 --- a/app/db/models/systemconfig.py +++ b/app/db/models/systemconfig.py @@ -1,8 +1,7 @@ from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy.orm import Session -from app.db import db_update, db_query -from app.db.models import Base +from app.db import db_query, db_update, Base class SystemConfig(Base): diff --git a/app/db/models/transferhistory.py b/app/db/models/transferhistory.py index a906b8a0..93948008 100644 --- a/app/db/models/transferhistory.py +++ b/app/db/models/transferhistory.py @@ -3,8 +3,7 @@ import time from sqlalchemy import Column, Integer, String, Sequence, Boolean, func from sqlalchemy.orm import Session -from app.db import db_query -from app.db.models import Base, db_update +from app.db import db_query, db_update, Base class TransferHistory(Base): diff --git a/app/db/models/user.py b/app/db/models/user.py index b058c48a..51aad923 100644 --- a/app/db/models/user.py +++ b/app/db/models/user.py @@ -2,8 +2,7 @@ from sqlalchemy import Boolean, Column, Integer, String, Sequence from sqlalchemy.orm import Session from app.core.security import verify_password -from app.db import db_update, db_query -from app.db.models import Base +from app.db import db_query, db_update, Base class User(Base): diff --git a/app/helper/module.py b/app/helper/module.py index f0e9ba6e..e4c7ce33 100644 --- a/app/helper/module.py +++ b/app/helper/module.py @@ -40,7 +40,7 @@ class ModuleHelper: @staticmethod def dynamic_import_all_modules(base_path: Path, package_name: str): """ - 动态导入所有模块到全局对象 + 动态导入目录下所有模块 """ modules = [] # 遍历文件夹,找到所有模块文件 @@ -48,14 +48,5 @@ class ModuleHelper: file_name = file.stem if file_name != "__init__": modules.append(file_name) - # 保存已有的全局对象 - existing_globals = set(globals().keys()) - # 动态导入并添加到全局命名空间 - for module in modules: - full_module_name = f"{package_name}.{module}" - import_module = importlib.import_module(full_module_name) - module_globals = import_module.__dict__ - # 仅导入全局对象中不存在的部分 - new_objects = {name: value for name, value in module_globals.items() if name not in existing_globals} - # 更新全局命名空间 - globals().update(new_objects) + full_module_name = f"{package_name}.{file_name}" + importlib.import_module(full_module_name) diff --git a/database/env.py b/database/env.py index 0192775e..0efb8484 100644 --- a/database/env.py +++ b/database/env.py @@ -5,7 +5,7 @@ from sqlalchemy import pool from alembic import context -from app.db.models import Base +from app.db import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. config = context.config diff --git a/version.py b/version.py index 36be3e77..ca0b8d10 100644 --- a/version.py +++ b/version.py @@ -1 +1 @@ -APP_VERSION = 'v1.4.8-1' +APP_VERSION = 'v1.4.8-2'