package user import ( "context" "errors" "fmt" "github.com/saas-mingyang/mingyang-admin-common/utils/sonyflake" "golang.org/x/crypto/bcrypt" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" "mingyang-admin-app-rpc/ent" "mingyang-admin-app-rpc/ent/predicate" "mingyang-admin-app-rpc/ent/user" "mingyang-admin-app-rpc/ent/userloginlog" "mingyang-admin-app-rpc/internal/jwt_manager" "time" "mingyang-admin-app-rpc/internal/svc" "mingyang-admin-app-rpc/types/app" "github.com/zeromicro/go-zero/core/logx" ) type LoginUserLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger jwt *jwt_manager.JWTManager } func NewLoginUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *LoginUserLogic { return &LoginUserLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), jwt: jwt_manager.NewJWTManager(&svcCtx.Config.JWTConf), } } // LoginUser 用户登录 func (l *LoginUserLogic) LoginUser(in *app.LoginRequest) (*app.LoginResponse, error) { // 1. 参数验证 if err := l.validateLoginRequest(in); err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } // 2. 根据登录标识查询用户 user, err := l.getUserByIdentifier(in) if err != nil { // 避免泄露用户是否存在的信息,统一返回相同错误 return nil, status.Error(codes.Unauthenticated, "invalid credentials") } // 3. 检查账户状态 if err := l.checkAccountStatus(user); err != nil { return nil, err } // 4. 验证密码 if !l.verifyPassword(user, in.GetPassword()) { // 记录登录失败尝试 go l.recordLogin(user, in, nil, false) return nil, status.Error(codes.Unauthenticated, "invalid credentials") } // 5. 生成访问令牌和刷新令牌 tokenPair, userThirdAuth, err := l.generateTokenPair(user, in) fmt.Printf("userThirdAuth: %v", userThirdAuth) if err != nil { fmt.Printf("generateTokenPair error: %v", err) return nil, status.Error(codes.Internal, "failed to generate token") } // 6. 记录成功登录日志 go l.recordLogin(user, in, userThirdAuth, true) // 7. 返回响应 return l.buildLoginResponse(user, tokenPair), nil } func (l *LoginUserLogic) validateLoginRequest(in *app.LoginRequest) error { // 检查登录标识是否提供 if in.GetUsername() == "" { return errors.New("username, mobile or email is required") } // 检查密码是否提供 if in.GetPassword() == "" { return errors.New("password is required") } return nil } func (l *LoginUserLogic) getUserByIdentifier(in *app.LoginRequest) (*ent.User, error) { var conditions []predicate.User // 根据提供的标识构建查询条件 conditions = append(conditions, user.UsernameEQ(in.GetUsername())) // 查询用户,同时加载关联数据(如角色、权限等) user, err := l.svcCtx.DB.User.Query(). Where(user.Or(conditions...)). Only(l.ctx) if err != nil { if ent.IsNotFound(err) { // 用户不存在,返回统一错误信息 return nil, errors.New("user not found") } return nil, fmt.Errorf("failed to query user: %w", err) } return user, nil } func (l *LoginUserLogic) checkAccountStatus(user *ent.User) error { switch user.AccountStatus { case "locked": // 锁定已过期,可以尝试解锁 case "suspended": return errors.New("account suspended") case "banned": return errors.New("account banned") case "deleted": return errors.New("account deleted") } return nil } func (l *LoginUserLogic) verifyPassword(user *ent.User, password string) bool { // 使用bcrypt验证密码 err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)) if err != nil { // 密码错误,增加登录尝试次数 l.incrementLoginAttempts(user) return false } // 密码正确,重置登录尝试次数 l.resetLoginAttempts(user) return true } func (l *LoginUserLogic) recordLogin(user *ent.User, in *app.LoginRequest, userThirdAuth *ent.UserThirdAuth, loginResult bool) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _, err := l.svcCtx.DB.UserLoginLog.Create(). SetUserID(user.ID). SetLoginTime(time.Now()). SetLoginIP(in.GetClientIp()). SetLoginType(userloginlog.LoginType(in.GetLoginTyp())). SetLoginPlatform(userloginlog.LoginPlatform(in.GetLoginPlatform())). SetFailureReason("password_incorrect"). SetLoginResult(loginResult). SetAuthID(userThirdAuth.ID). Save(ctx) if err != nil { fmt.Printf("failed to record login: %v", err) l.Logger.Errorf("failed to record failed login: %v", err) } } func (l *LoginUserLogic) generateTokenPair(user *ent.User, in *app.LoginRequest) (*jwt_manager.TokenPair, *ent.UserThirdAuth, error) { // 生成访问令牌 tokenPair, err := l.jwt.GenerateTokenPair(user.ID) if err != nil { return nil, nil, fmt.Errorf("failed to generate access token: %w", err) } // 保存令牌到数据库 userThirdAuth, err := l.saveRefreshToken(user, tokenPair, in) return tokenPair, userThirdAuth, err } // 构建登录响应 func (l *LoginUserLogic) buildLoginResponse(user *ent.User, tokenPair *jwt_manager.TokenPair) *app.LoginResponse { response := &app.LoginResponse{ User: l.convertUserToProto(user), AuthToken: &app.AuthToken{ AccessToken: tokenPair.AccessToken, RefreshToken: tokenPair.RefreshToken, TokenType: tokenPair.TokenType, AccessTokenExpires: timestamppb.New(tokenPair.AccessTokenExpiresAt), RefreshTokenExpires: timestamppb.New(tokenPair.RefreshTokenExpiresAt), }, } return response } // 增加登录尝试次数 func (l *LoginUserLogic) incrementLoginAttempts(user *ent.User) { } // 重置登录尝试次数 func (l *LoginUserLogic) resetLoginAttempts(user *ent.User) { } // 保存刷新令牌到数据库 func (l *LoginUserLogic) saveRefreshToken(user *ent.User, tokenPair *jwt_manager.TokenPair, in *app.LoginRequest) (*ent.UserThirdAuth, error) { userThirdAuth, err := l.svcCtx.DB.UserThirdAuth.Create(). SetUserID(user.ID). SetRefreshToken(tokenPair.RefreshToken). SetAccessToken(tokenPair.AccessToken). SetAccessTokenExpiry(tokenPair.AccessTokenExpiresAt). SetBoundAt(time.Now()). SetID(sonyflake.NextID()). SetEmail(user.Email). SetMobile(*user.Mobile). Save(l.ctx) if err != nil { l.Errorf("saveRefreshToken err: %v", err) return nil, err } return userThirdAuth, nil } func (l *LoginUserLogic) convertUserToProto(u *ent.User) *app.UserInfo { return &app.UserInfo{ Id: &u.ID, Username: &u.Username, Nickname: &u.Nickname, Avatar: u.Avatar, Email: &u.Email, } }