feat #1763
This commit is contained in:
parent
ff07841dd6
commit
40d99f1dd5
@ -1,7 +1,7 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Form
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -15,13 +15,16 @@ from app.db import get_db
|
||||
from app.db.models.user import User
|
||||
from app.log import logger
|
||||
from app.utils.web import WebUtils
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/access-token", summary="获取token", response_model=schemas.Token)
|
||||
async def login_access_token(
|
||||
db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()
|
||||
db: Session = Depends(get_db),
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
otp_password: str = Form(None)
|
||||
) -> Any:
|
||||
"""
|
||||
获取认证Token
|
||||
@ -30,7 +33,8 @@ async def login_access_token(
|
||||
user = User.authenticate(
|
||||
db=db,
|
||||
name=form_data.username,
|
||||
password=form_data.password
|
||||
password=form_data.password,
|
||||
otp_password=otp_password
|
||||
)
|
||||
if not user:
|
||||
# 请求协助认证
|
||||
@ -38,7 +42,7 @@ async def login_access_token(
|
||||
token = UserChain().user_authenticate(form_data.username, form_data.password)
|
||||
if not token:
|
||||
logger.warn(f"用户 {form_data.username} 登录失败!")
|
||||
raise HTTPException(status_code=401, detail="用户名或密码不正确")
|
||||
raise HTTPException(status_code=401, detail="用户名、密码、二次校验不正确")
|
||||
else:
|
||||
logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...")
|
||||
# 加入用户信息表
|
||||
|
@ -10,6 +10,7 @@ from app.core.security import get_password_hash
|
||||
from app.db import get_db
|
||||
from app.db.models.user import User
|
||||
from app.db.userauth import get_current_active_superuser, get_current_active_user
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -141,3 +142,34 @@ def read_user_by_id(
|
||||
detail="用户权限不足"
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response)
|
||||
def otp_generate(
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> Any:
|
||||
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||
|
||||
|
||||
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
|
||||
def otp_judge(
|
||||
data: dict,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> Any:
|
||||
uri = data.get("uri")
|
||||
otp_password = data.get("otpPassword")
|
||||
if not OtpUtils.is_legal(uri, otp_password):
|
||||
return schemas.Response(success=False, message="验证码错误")
|
||||
current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
|
||||
def otp_disable(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_active_user)
|
||||
) -> Any:
|
||||
current_user.update_otp_by_name(db, current_user.name, False, "")
|
||||
return schemas.Response(success=True)
|
||||
|
@ -3,7 +3,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.security import verify_password
|
||||
from app.db import db_query, db_update, Base
|
||||
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
class User(Base):
|
||||
"""
|
||||
@ -23,15 +23,22 @@ class User(Base):
|
||||
is_superuser = Column(Boolean(), default=False)
|
||||
# 头像
|
||||
avatar = Column(String)
|
||||
# 是否启用otp二次验证
|
||||
is_otp = Column(Boolean(), default=False)
|
||||
# otp秘钥
|
||||
otp_secret = Column(String, default=None)
|
||||
|
||||
@staticmethod
|
||||
@db_query
|
||||
def authenticate(db: Session, name: str, password: str):
|
||||
def authenticate(db: Session, name: str, password: str, otp_password: str):
|
||||
user = db.query(User).filter(User.name == name).first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, str(user.hashed_password)):
|
||||
return None
|
||||
if user.is_otp:
|
||||
if not otp_password or not OtpUtils.check(user.otp_secret, otp_password):
|
||||
return None
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
@ -45,3 +52,14 @@ class User(Base):
|
||||
if user:
|
||||
user.delete(db, user.id)
|
||||
return True
|
||||
|
||||
@db_update
|
||||
def update_otp_by_name(self, db: Session, name: str, otp: bool, secret: str):
|
||||
user = self.get_by_name(db, name)
|
||||
if user:
|
||||
user.update(db, {
|
||||
'is_otp': otp,
|
||||
'otp_secret': secret
|
||||
})
|
||||
return True
|
||||
return False
|
||||
|
@ -15,6 +15,8 @@ class UserBase(BaseModel):
|
||||
is_superuser: bool = False
|
||||
# 头像
|
||||
avatar: Optional[str] = None
|
||||
# 是否开启二次验证
|
||||
is_otp: Optional[bool] = False
|
||||
|
||||
|
||||
# Properties to receive via API on creation
|
||||
|
48
app/utils/otp.py
Normal file
48
app/utils/otp.py
Normal file
@ -0,0 +1,48 @@
|
||||
import pyotp
|
||||
|
||||
|
||||
class OtpUtils:
|
||||
@staticmethod
|
||||
def generate_secret_key(username: str) -> (str, str):
|
||||
try:
|
||||
secret = pyotp.random_base32()
|
||||
uri = pyotp.totp.TOTP(secret).provisioning_uri(name='MoviePilot',
|
||||
issuer_name='MoviePilot(' + username + ')')
|
||||
return secret, uri
|
||||
except Exception as err:
|
||||
print(str(err))
|
||||
return "", ""
|
||||
|
||||
@staticmethod
|
||||
def is_legal(otp_uri: str, password: str) -> bool:
|
||||
"""
|
||||
校验二次验证是否正确
|
||||
"""
|
||||
try:
|
||||
return pyotp.TOTP(pyotp.parse_uri(otp_uri).secret).verify(password)
|
||||
except Exception as err:
|
||||
print(str(err))
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def check(secret: str, password: str) -> bool:
|
||||
"""
|
||||
校验二次验证是否正确
|
||||
"""
|
||||
try:
|
||||
totp = pyotp.TOTP(secret)
|
||||
return totp.verify(password)
|
||||
except Exception as err:
|
||||
print(str(err))
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_secret(otp_uri: str) -> str:
|
||||
"""
|
||||
获取uri中的secret
|
||||
"""
|
||||
try:
|
||||
return pyotp.parse_uri(otp_uri).secret
|
||||
except Exception as err:
|
||||
print(str(err))
|
||||
return ""
|
30
database/versions/9cb3993e340e_1_0_17.py
Normal file
30
database/versions/9cb3993e340e_1_0_17.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""1_0_17
|
||||
|
||||
Revision ID: 9cb3993e340e
|
||||
Revises: d146dea51516
|
||||
Create Date: 2024-03-28 14:36:35.588392
|
||||
|
||||
"""
|
||||
import contextlib
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '9cb3993e340e'
|
||||
down_revision = 'd146dea51516'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with contextlib.suppress(Exception):
|
||||
with op.batch_alter_table("user") as batch_op:
|
||||
batch_op.add_column(sa.Column('is_otp', sa.BOOLEAN, server_default='0'))
|
||||
batch_op.add_column(sa.Column('otp_secret', sa.VARCHAR))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
@ -54,4 +54,5 @@ parse~=1.19.0
|
||||
docker~=6.1.3
|
||||
cachetools~=5.3.1
|
||||
fast-bencode~=1.1.3
|
||||
pystray~=0.19.5
|
||||
pystray~=0.19.5
|
||||
pypushdeer~=0.0.3
|
Loading…
x
Reference in New Issue
Block a user