This commit is contained in:
zss 2024-03-28 16:39:34 +08:00
parent ff07841dd6
commit 40d99f1dd5
7 changed files with 142 additions and 7 deletions

View File

@ -1,7 +1,7 @@
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Form
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -15,13 +15,16 @@ from app.db import get_db
from app.db.models.user import User from app.db.models.user import User
from app.log import logger from app.log import logger
from app.utils.web import WebUtils from app.utils.web import WebUtils
from app.utils.otp import OtpUtils
router = APIRouter() router = APIRouter()
@router.post("/access-token", summary="获取token", response_model=schemas.Token) @router.post("/access-token", summary="获取token", response_model=schemas.Token)
async def login_access_token( async def login_access_token(
db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends() db: Session = Depends(get_db),
form_data: OAuth2PasswordRequestForm = Depends(),
otp_password: str = Form(None)
) -> Any: ) -> Any:
""" """
获取认证Token 获取认证Token
@ -30,7 +33,8 @@ async def login_access_token(
user = User.authenticate( user = User.authenticate(
db=db, db=db,
name=form_data.username, name=form_data.username,
password=form_data.password password=form_data.password,
otp_password=otp_password
) )
if not user: if not user:
# 请求协助认证 # 请求协助认证
@ -38,7 +42,7 @@ async def login_access_token(
token = UserChain().user_authenticate(form_data.username, form_data.password) token = UserChain().user_authenticate(form_data.username, form_data.password)
if not token: if not token:
logger.warn(f"用户 {form_data.username} 登录失败!") logger.warn(f"用户 {form_data.username} 登录失败!")
raise HTTPException(status_code=401, detail="用户名或密码不正确") raise HTTPException(status_code=401, detail="用户名、密码、二次校验不正确")
else: else:
logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...") logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...")
# 加入用户信息表 # 加入用户信息表

View File

@ -10,6 +10,7 @@ from app.core.security import get_password_hash
from app.db import get_db from app.db import get_db
from app.db.models.user import User from app.db.models.user import User
from app.db.userauth import get_current_active_superuser, get_current_active_user from app.db.userauth import get_current_active_superuser, get_current_active_user
from app.utils.otp import OtpUtils
router = APIRouter() router = APIRouter()
@ -141,3 +142,34 @@ def read_user_by_id(
detail="用户权限不足" detail="用户权限不足"
) )
return user return user
@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response)
def otp_generate(
current_user: User = Depends(get_current_active_user)
) -> Any:
secret, uri = OtpUtils.generate_secret_key(current_user.name)
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
def otp_judge(
data: dict,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
) -> Any:
uri = data.get("uri")
otp_password = data.get("otpPassword")
if not OtpUtils.is_legal(uri, otp_password):
return schemas.Response(success=False, message="验证码错误")
current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
return schemas.Response(success=True)
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
def otp_disable(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
) -> Any:
current_user.update_otp_by_name(db, current_user.name, False, "")
return schemas.Response(success=True)

View File

@ -3,7 +3,7 @@ from sqlalchemy.orm import Session
from app.core.security import verify_password from app.core.security import verify_password
from app.db import db_query, db_update, Base from app.db import db_query, db_update, Base
from app.utils.otp import OtpUtils
class User(Base): class User(Base):
""" """
@ -23,15 +23,22 @@ class User(Base):
is_superuser = Column(Boolean(), default=False) is_superuser = Column(Boolean(), default=False)
# 头像 # 头像
avatar = Column(String) avatar = Column(String)
# 是否启用otp二次验证
is_otp = Column(Boolean(), default=False)
# otp秘钥
otp_secret = Column(String, default=None)
@staticmethod @staticmethod
@db_query @db_query
def authenticate(db: Session, name: str, password: str): def authenticate(db: Session, name: str, password: str, otp_password: str):
user = db.query(User).filter(User.name == name).first() user = db.query(User).filter(User.name == name).first()
if not user: if not user:
return None return None
if not verify_password(password, str(user.hashed_password)): if not verify_password(password, str(user.hashed_password)):
return None return None
if user.is_otp:
if not otp_password or not OtpUtils.check(user.otp_secret, otp_password):
return None
return user return user
@staticmethod @staticmethod
@ -45,3 +52,14 @@ class User(Base):
if user: if user:
user.delete(db, user.id) user.delete(db, user.id)
return True return True
@db_update
def update_otp_by_name(self, db: Session, name: str, otp: bool, secret: str):
user = self.get_by_name(db, name)
if user:
user.update(db, {
'is_otp': otp,
'otp_secret': secret
})
return True
return False

View File

@ -15,6 +15,8 @@ class UserBase(BaseModel):
is_superuser: bool = False is_superuser: bool = False
# 头像 # 头像
avatar: Optional[str] = None avatar: Optional[str] = None
# 是否开启二次验证
is_otp: Optional[bool] = False
# Properties to receive via API on creation # Properties to receive via API on creation

48
app/utils/otp.py Normal file
View File

@ -0,0 +1,48 @@
import pyotp
class OtpUtils:
@staticmethod
def generate_secret_key(username: str) -> (str, str):
try:
secret = pyotp.random_base32()
uri = pyotp.totp.TOTP(secret).provisioning_uri(name='MoviePilot',
issuer_name='MoviePilot(' + username + ')')
return secret, uri
except Exception as err:
print(str(err))
return "", ""
@staticmethod
def is_legal(otp_uri: str, password: str) -> bool:
"""
校验二次验证是否正确
"""
try:
return pyotp.TOTP(pyotp.parse_uri(otp_uri).secret).verify(password)
except Exception as err:
print(str(err))
return False
@staticmethod
def check(secret: str, password: str) -> bool:
"""
校验二次验证是否正确
"""
try:
totp = pyotp.TOTP(secret)
return totp.verify(password)
except Exception as err:
print(str(err))
return False
@staticmethod
def get_secret(otp_uri: str) -> str:
"""
获取uri中的secret
"""
try:
return pyotp.parse_uri(otp_uri).secret
except Exception as err:
print(str(err))
return ""

View File

@ -0,0 +1,30 @@
"""1_0_17
Revision ID: 9cb3993e340e
Revises: d146dea51516
Create Date: 2024-03-28 14:36:35.588392
"""
import contextlib
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '9cb3993e340e'
down_revision = 'd146dea51516'
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with contextlib.suppress(Exception):
with op.batch_alter_table("user") as batch_op:
batch_op.add_column(sa.Column('is_otp', sa.BOOLEAN, server_default='0'))
batch_op.add_column(sa.Column('otp_secret', sa.VARCHAR))
# ### end Alembic commands ###
def downgrade() -> None:
pass

View File

@ -54,4 +54,5 @@ parse~=1.19.0
docker~=6.1.3 docker~=6.1.3
cachetools~=5.3.1 cachetools~=5.3.1
fast-bencode~=1.1.3 fast-bencode~=1.1.3
pystray~=0.19.5 pystray~=0.19.5
pypushdeer~=0.0.3