feat #1763
This commit is contained in:
parent
ff07841dd6
commit
40d99f1dd5
@ -1,7 +1,7 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Form
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@ -15,13 +15,16 @@ from app.db import get_db
|
|||||||
from app.db.models.user import User
|
from app.db.models.user import User
|
||||||
from app.log import logger
|
from app.log import logger
|
||||||
from app.utils.web import WebUtils
|
from app.utils.web import WebUtils
|
||||||
|
from app.utils.otp import OtpUtils
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/access-token", summary="获取token", response_model=schemas.Token)
|
@router.post("/access-token", summary="获取token", response_model=schemas.Token)
|
||||||
async def login_access_token(
|
async def login_access_token(
|
||||||
db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()
|
db: Session = Depends(get_db),
|
||||||
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
|
otp_password: str = Form(None)
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
获取认证Token
|
获取认证Token
|
||||||
@ -30,7 +33,8 @@ async def login_access_token(
|
|||||||
user = User.authenticate(
|
user = User.authenticate(
|
||||||
db=db,
|
db=db,
|
||||||
name=form_data.username,
|
name=form_data.username,
|
||||||
password=form_data.password
|
password=form_data.password,
|
||||||
|
otp_password=otp_password
|
||||||
)
|
)
|
||||||
if not user:
|
if not user:
|
||||||
# 请求协助认证
|
# 请求协助认证
|
||||||
@ -38,7 +42,7 @@ async def login_access_token(
|
|||||||
token = UserChain().user_authenticate(form_data.username, form_data.password)
|
token = UserChain().user_authenticate(form_data.username, form_data.password)
|
||||||
if not token:
|
if not token:
|
||||||
logger.warn(f"用户 {form_data.username} 登录失败!")
|
logger.warn(f"用户 {form_data.username} 登录失败!")
|
||||||
raise HTTPException(status_code=401, detail="用户名或密码不正确")
|
raise HTTPException(status_code=401, detail="用户名、密码、二次校验不正确")
|
||||||
else:
|
else:
|
||||||
logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...")
|
logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...")
|
||||||
# 加入用户信息表
|
# 加入用户信息表
|
||||||
|
@ -10,6 +10,7 @@ from app.core.security import get_password_hash
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.db.models.user import User
|
from app.db.models.user import User
|
||||||
from app.db.userauth import get_current_active_superuser, get_current_active_user
|
from app.db.userauth import get_current_active_superuser, get_current_active_user
|
||||||
|
from app.utils.otp import OtpUtils
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -141,3 +142,34 @@ def read_user_by_id(
|
|||||||
detail="用户权限不足"
|
detail="用户权限不足"
|
||||||
)
|
)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response)
|
||||||
|
def otp_generate(
|
||||||
|
current_user: User = Depends(get_current_active_user)
|
||||||
|
) -> Any:
|
||||||
|
secret, uri = OtpUtils.generate_secret_key(current_user.name)
|
||||||
|
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
|
||||||
|
def otp_judge(
|
||||||
|
data: dict,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_active_user)
|
||||||
|
) -> Any:
|
||||||
|
uri = data.get("uri")
|
||||||
|
otp_password = data.get("otpPassword")
|
||||||
|
if not OtpUtils.is_legal(uri, otp_password):
|
||||||
|
return schemas.Response(success=False, message="验证码错误")
|
||||||
|
current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
|
||||||
|
return schemas.Response(success=True)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
|
||||||
|
def otp_disable(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_active_user)
|
||||||
|
) -> Any:
|
||||||
|
current_user.update_otp_by_name(db, current_user.name, False, "")
|
||||||
|
return schemas.Response(success=True)
|
||||||
|
@ -3,7 +3,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.core.security import verify_password
|
from app.core.security import verify_password
|
||||||
from app.db import db_query, db_update, Base
|
from app.db import db_query, db_update, Base
|
||||||
|
from app.utils.otp import OtpUtils
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
"""
|
"""
|
||||||
@ -23,15 +23,22 @@ class User(Base):
|
|||||||
is_superuser = Column(Boolean(), default=False)
|
is_superuser = Column(Boolean(), default=False)
|
||||||
# 头像
|
# 头像
|
||||||
avatar = Column(String)
|
avatar = Column(String)
|
||||||
|
# 是否启用otp二次验证
|
||||||
|
is_otp = Column(Boolean(), default=False)
|
||||||
|
# otp秘钥
|
||||||
|
otp_secret = Column(String, default=None)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@db_query
|
@db_query
|
||||||
def authenticate(db: Session, name: str, password: str):
|
def authenticate(db: Session, name: str, password: str, otp_password: str):
|
||||||
user = db.query(User).filter(User.name == name).first()
|
user = db.query(User).filter(User.name == name).first()
|
||||||
if not user:
|
if not user:
|
||||||
return None
|
return None
|
||||||
if not verify_password(password, str(user.hashed_password)):
|
if not verify_password(password, str(user.hashed_password)):
|
||||||
return None
|
return None
|
||||||
|
if user.is_otp:
|
||||||
|
if not otp_password or not OtpUtils.check(user.otp_secret, otp_password):
|
||||||
|
return None
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -45,3 +52,14 @@ class User(Base):
|
|||||||
if user:
|
if user:
|
||||||
user.delete(db, user.id)
|
user.delete(db, user.id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@db_update
|
||||||
|
def update_otp_by_name(self, db: Session, name: str, otp: bool, secret: str):
|
||||||
|
user = self.get_by_name(db, name)
|
||||||
|
if user:
|
||||||
|
user.update(db, {
|
||||||
|
'is_otp': otp,
|
||||||
|
'otp_secret': secret
|
||||||
|
})
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
@ -15,6 +15,8 @@ class UserBase(BaseModel):
|
|||||||
is_superuser: bool = False
|
is_superuser: bool = False
|
||||||
# 头像
|
# 头像
|
||||||
avatar: Optional[str] = None
|
avatar: Optional[str] = None
|
||||||
|
# 是否开启二次验证
|
||||||
|
is_otp: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
# Properties to receive via API on creation
|
# Properties to receive via API on creation
|
||||||
|
48
app/utils/otp.py
Normal file
48
app/utils/otp.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import pyotp
|
||||||
|
|
||||||
|
|
||||||
|
class OtpUtils:
|
||||||
|
@staticmethod
|
||||||
|
def generate_secret_key(username: str) -> (str, str):
|
||||||
|
try:
|
||||||
|
secret = pyotp.random_base32()
|
||||||
|
uri = pyotp.totp.TOTP(secret).provisioning_uri(name='MoviePilot',
|
||||||
|
issuer_name='MoviePilot(' + username + ')')
|
||||||
|
return secret, uri
|
||||||
|
except Exception as err:
|
||||||
|
print(str(err))
|
||||||
|
return "", ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_legal(otp_uri: str, password: str) -> bool:
|
||||||
|
"""
|
||||||
|
校验二次验证是否正确
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return pyotp.TOTP(pyotp.parse_uri(otp_uri).secret).verify(password)
|
||||||
|
except Exception as err:
|
||||||
|
print(str(err))
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check(secret: str, password: str) -> bool:
|
||||||
|
"""
|
||||||
|
校验二次验证是否正确
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
totp = pyotp.TOTP(secret)
|
||||||
|
return totp.verify(password)
|
||||||
|
except Exception as err:
|
||||||
|
print(str(err))
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_secret(otp_uri: str) -> str:
|
||||||
|
"""
|
||||||
|
获取uri中的secret
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return pyotp.parse_uri(otp_uri).secret
|
||||||
|
except Exception as err:
|
||||||
|
print(str(err))
|
||||||
|
return ""
|
30
database/versions/9cb3993e340e_1_0_17.py
Normal file
30
database/versions/9cb3993e340e_1_0_17.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
"""1_0_17
|
||||||
|
|
||||||
|
Revision ID: 9cb3993e340e
|
||||||
|
Revises: d146dea51516
|
||||||
|
Create Date: 2024-03-28 14:36:35.588392
|
||||||
|
|
||||||
|
"""
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '9cb3993e340e'
|
||||||
|
down_revision = 'd146dea51516'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
with op.batch_alter_table("user") as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('is_otp', sa.BOOLEAN, server_default='0'))
|
||||||
|
batch_op.add_column(sa.Column('otp_secret', sa.VARCHAR))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
pass
|
@ -54,4 +54,5 @@ parse~=1.19.0
|
|||||||
docker~=6.1.3
|
docker~=6.1.3
|
||||||
cachetools~=5.3.1
|
cachetools~=5.3.1
|
||||||
fast-bencode~=1.1.3
|
fast-bencode~=1.1.3
|
||||||
pystray~=0.19.5
|
pystray~=0.19.5
|
||||||
|
pypushdeer~=0.0.3
|
Loading…
x
Reference in New Issue
Block a user