216 lines
6.5 KiB
Go
216 lines
6.5 KiB
Go
package user
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/saas-mingyang/mingyang-admin-common/utils/sonyflake"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
"mingyang-admin-app-rpc/ent"
|
|
"mingyang-admin-app-rpc/ent/predicate"
|
|
"mingyang-admin-app-rpc/ent/user"
|
|
"mingyang-admin-app-rpc/ent/userloginlog"
|
|
"mingyang-admin-app-rpc/internal/jwt_manager"
|
|
"time"
|
|
|
|
"mingyang-admin-app-rpc/internal/svc"
|
|
"mingyang-admin-app-rpc/types/app"
|
|
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|
)
|
|
|
|
type LoginUserLogic struct {
|
|
ctx context.Context
|
|
svcCtx *svc.ServiceContext
|
|
logx.Logger
|
|
jwt *jwt_manager.JWTManager
|
|
}
|
|
|
|
func NewLoginUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginUserLogic {
|
|
return &LoginUserLogic{
|
|
ctx: ctx,
|
|
svcCtx: svcCtx,
|
|
Logger: logx.WithContext(ctx),
|
|
jwt: jwt_manager.NewJWTManager(&svcCtx.Config.JWTConf),
|
|
}
|
|
}
|
|
|
|
// LoginUser 用户登录
|
|
func (l *LoginUserLogic) LoginUser(in *app.LoginRequest) (*app.LoginResponse, error) {
|
|
// 1. 参数验证
|
|
if err := l.validateLoginRequest(in); err != nil {
|
|
return nil, status.Error(codes.InvalidArgument, err.Error())
|
|
}
|
|
// 2. 根据登录标识查询用户
|
|
user, err := l.getUserByIdentifier(in)
|
|
if err != nil {
|
|
// 避免泄露用户是否存在的信息,统一返回相同错误
|
|
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
|
}
|
|
// 3. 检查账户状态
|
|
if err := l.checkAccountStatus(user); err != nil {
|
|
return nil, err
|
|
}
|
|
// 4. 验证密码
|
|
if !l.verifyPassword(user, in.GetPassword()) {
|
|
// 记录登录失败尝试
|
|
go l.recordLogin(user, in, nil, false)
|
|
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
|
}
|
|
// 5. 生成访问令牌和刷新令牌
|
|
tokenPair, userThirdAuth, err := l.generateTokenPair(user, in)
|
|
fmt.Printf("userThirdAuth: %v", userThirdAuth)
|
|
if err != nil {
|
|
fmt.Printf("generateTokenPair error: %v", err)
|
|
return nil, status.Error(codes.Internal, "failed to generate token")
|
|
}
|
|
// 6. 记录成功登录日志
|
|
go l.recordLogin(user, in, userThirdAuth, true)
|
|
// 7. 返回响应
|
|
return l.buildLoginResponse(user, tokenPair), nil
|
|
}
|
|
|
|
func (l *LoginUserLogic) validateLoginRequest(in *app.LoginRequest) error {
|
|
// 检查登录标识是否提供
|
|
if in.GetUsername() == "" {
|
|
return errors.New("username, mobile or email is required")
|
|
}
|
|
// 检查密码是否提供
|
|
if in.GetPassword() == "" {
|
|
return errors.New("password is required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (l *LoginUserLogic) getUserByIdentifier(in *app.LoginRequest) (*ent.User, error) {
|
|
var conditions []predicate.User
|
|
// 根据提供的标识构建查询条件
|
|
conditions = append(conditions, user.UsernameEQ(in.GetUsername()))
|
|
// 查询用户,同时加载关联数据(如角色、权限等)
|
|
user, err := l.svcCtx.DB.User.Query().
|
|
Where(user.Or(conditions...)).
|
|
Only(l.ctx)
|
|
if err != nil {
|
|
if ent.IsNotFound(err) {
|
|
// 用户不存在,返回统一错误信息
|
|
return nil, errors.New("user not found")
|
|
}
|
|
return nil, fmt.Errorf("failed to query user: %w", err)
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (l *LoginUserLogic) checkAccountStatus(user *ent.User) error {
|
|
switch user.AccountStatus {
|
|
case "locked":
|
|
// 锁定已过期,可以尝试解锁
|
|
case "suspended":
|
|
return errors.New("account suspended")
|
|
case "banned":
|
|
return errors.New("account banned")
|
|
case "deleted":
|
|
return errors.New("account deleted")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (l *LoginUserLogic) verifyPassword(user *ent.User, password string) bool {
|
|
// 使用bcrypt验证密码
|
|
err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password))
|
|
if err != nil {
|
|
// 密码错误,增加登录尝试次数
|
|
l.incrementLoginAttempts(user)
|
|
return false
|
|
}
|
|
// 密码正确,重置登录尝试次数
|
|
l.resetLoginAttempts(user)
|
|
return true
|
|
}
|
|
|
|
func (l *LoginUserLogic) recordLogin(user *ent.User, in *app.LoginRequest, userThirdAuth *ent.UserThirdAuth, loginResult bool) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
_, err := l.svcCtx.DB.UserLoginLog.Create().
|
|
SetUserID(user.ID).
|
|
SetLoginTime(time.Now()).
|
|
SetLoginIP(in.GetClientIp()).
|
|
SetLoginType(userloginlog.LoginType(in.GetLoginTyp())).
|
|
SetLoginPlatform(userloginlog.LoginPlatform(in.GetLoginPlatform())).
|
|
SetFailureReason("password_incorrect").
|
|
SetLoginResult(loginResult).
|
|
SetAuthID(userThirdAuth.ID).
|
|
Save(ctx)
|
|
if err != nil {
|
|
fmt.Printf("failed to record login: %v", err)
|
|
l.Logger.Errorf("failed to record failed login: %v", err)
|
|
}
|
|
}
|
|
|
|
func (l *LoginUserLogic) generateTokenPair(user *ent.User, in *app.LoginRequest) (*jwt_manager.TokenPair, *ent.UserThirdAuth, error) {
|
|
// 生成访问令牌
|
|
tokenPair, err := l.jwt.GenerateTokenPair(user.ID)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("failed to generate access token: %w", err)
|
|
}
|
|
// 保存令牌到数据库
|
|
userThirdAuth, err := l.saveRefreshToken(user, tokenPair, in)
|
|
return tokenPair, userThirdAuth, err
|
|
}
|
|
|
|
// 构建登录响应
|
|
func (l *LoginUserLogic) buildLoginResponse(user *ent.User, tokenPair *jwt_manager.TokenPair) *app.LoginResponse {
|
|
response := &app.LoginResponse{
|
|
User: l.convertUserToProto(user),
|
|
AuthToken: &app.AuthToken{
|
|
AccessToken: tokenPair.AccessToken,
|
|
RefreshToken: tokenPair.RefreshToken,
|
|
TokenType: tokenPair.TokenType,
|
|
AccessTokenExpires: timestamppb.New(tokenPair.AccessTokenExpiresAt),
|
|
RefreshTokenExpires: timestamppb.New(tokenPair.RefreshTokenExpiresAt),
|
|
},
|
|
}
|
|
return response
|
|
}
|
|
|
|
// 增加登录尝试次数
|
|
func (l *LoginUserLogic) incrementLoginAttempts(user *ent.User) {
|
|
|
|
}
|
|
|
|
// 重置登录尝试次数
|
|
func (l *LoginUserLogic) resetLoginAttempts(user *ent.User) {
|
|
|
|
}
|
|
|
|
// 保存刷新令牌到数据库
|
|
func (l *LoginUserLogic) saveRefreshToken(user *ent.User, tokenPair *jwt_manager.TokenPair, in *app.LoginRequest) (*ent.UserThirdAuth, error) {
|
|
userThirdAuth, err := l.svcCtx.DB.UserThirdAuth.Create().
|
|
SetUserID(user.ID).
|
|
SetRefreshToken(tokenPair.RefreshToken).
|
|
SetAccessToken(tokenPair.AccessToken).
|
|
SetAccessTokenExpiry(tokenPair.AccessTokenExpiresAt).
|
|
SetBoundAt(time.Now()).
|
|
SetID(sonyflake.NextID()).
|
|
SetEmail(user.Email).
|
|
SetMobile(*user.Mobile).
|
|
Save(l.ctx)
|
|
if err != nil {
|
|
l.Errorf("saveRefreshToken err: %v", err)
|
|
return nil, err
|
|
}
|
|
return userThirdAuth, nil
|
|
}
|
|
|
|
func (l *LoginUserLogic) convertUserToProto(u *ent.User) *app.UserInfo {
|
|
return &app.UserInfo{
|
|
Id: &u.ID,
|
|
Username: &u.Username,
|
|
Nickname: &u.Nickname,
|
|
Avatar: u.Avatar,
|
|
Email: &u.Email,
|
|
}
|
|
}
|