package user

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/zeromicro/go-zero/core/logx"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	"mingyang-admin-app-rpc/ent/user"
	"mingyang-admin-app-rpc/internal/jwt_manager"
	"mingyang-admin-app-rpc/internal/logic/cacherepo"
	"mingyang-admin-app-rpc/internal/svc"
	"mingyang-admin-app-rpc/internal/util"
	"mingyang-admin-app-rpc/types/app"
)

// RegisterUserLogic handles the RegisterUser RPC. It carries the request
// context, the service context, a context-bound logger, a JWT manager and a
// cache repository used for verification codes.
type RegisterUserLogic struct {
	ctx    context.Context
	svcCtx *svc.ServiceContext
	logx.Logger
	jwtManager *jwt_manager.JWTManager
	cacheRepo  *cacherepo.CacheRepository
}

// NewRegisterUserLogic builds a RegisterUserLogic bound to ctx. The JWT
// manager is configured from svcCtx.Config.JWTConf.
func NewRegisterUserLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RegisterUserLogic {
	return &RegisterUserLogic{
		ctx:        ctx,
		svcCtx:     svcCtx,
		jwtManager: jwt_manager.NewJWTManager(&svcCtx.Config.JWTConf),
		Logger:     logx.WithContext(ctx),
		cacheRepo:  cacherepo.NewCacheRepository(ctx, svcCtx),
	}
}

// RegisterUser registers a new user (用户注册).
//
// It validates the request, checks the email/SMS verification code, checks
// password strength, hashes the password, persists the user and then tries to
// issue a JWT token pair.
//
// Failures before the user is persisted are returned as gRPC status errors:
// InvalidArgument for validation/verification/password problems,
// AlreadyExists when the repository reports a duplicate, and Internal
// otherwise. Once the user has been created the call succeeds. If the token
// pair cannot be issued, the failure is logged and AuthToken is left nil.
func (s *RegisterUserLogic) RegisterUser(req *app.RegisterUserRequest) (*app.RegisterUserResponse, error) {
	// 1. Validate the request.
	if err := s.validateRegisterRequest(req); err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "validation failed: %v", err)
	}

	// 2. Check the verification code.
	if err := s.verifyRegistrationCode(s.ctx, req); err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "verification failed: %v", err)
	}

	// 3. Check password strength.
	if err := util.ValidatePasswordStrength(req.GetPassword()); err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "password validation failed: %v", err)
	}

	// 4. Hash the password.
	passwordHash, err := util.HashPassword(req.GetPassword())
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to hash password: %v", err)
	}

	// 5. Normalize the mobile number; it stays empty when none was supplied.
	var mobile string
	if req.GetMobile() != "" {
		mobile = util.NormalizePhone(req.GetMobile())
	}

	// 6. Assemble the record to persist. Username and email are trimmed and
	// lowercased.
	userData := CreateUserData{
		Username:           strings.ToLower(strings.TrimSpace(req.GetUsername())),
		Email:              strings.ToLower(strings.TrimSpace(req.GetEmail())),
		Mobile:             mobile,
		PasswordHash:       passwordHash,
		Gender:             user.Gender(req.GetGender()),
		Nickname:           req.GetNickName(),
		RegistrationSource: req.GetRegistrationSource(),
		Metadata: map[string]interface{}{
			"registered_via": req.GetRegistrationSource(),
			"registered_at":  time.Now().Format(time.RFC3339),
		},
	}

	// 7. Persist the user. The result is deliberately not named `user`,
	// which would shadow the imported ent `user` package.
	createdUser, err := NewUser(s.ctx, s.svcCtx).CreateUser(s.ctx, &userData)
	if err != nil {
		// NOTE(review): matching on the error text is brittle; prefer a
		// sentinel/typed error from CreateUser if one is available.
		if strings.Contains(err.Error(), "already exists") {
			return nil, status.Errorf(codes.AlreadyExists, "user already exists")
		}
		return nil, status.Errorf(codes.Internal, "failed to create user: %v", err)
	}

	// TODO: send the welcome email asynchronously (not implemented yet).

	response := &app.RegisterUserResponse{
		User: createdUser,
	}

	// 8. Issue a JWT token pair. The user already exists at this point, so
	// failures are logged and the registration still succeeds without tokens.
	if createdUser == nil || createdUser.Id == nil {
		s.Logger.Errorf("created user has no id; skipping token generation")
		return response, nil
	}
	userID := *createdUser.Id
	tokenPair, err := s.jwtManager.GenerateTokenPair(userID)
	if err != nil {
		s.Logger.Errorf("failed to generate token for user %v: %v", userID, err)
		return response, nil
	}

	// 9. Attach the tokens to the response.
	if tokenPair != nil {
		response.AuthToken = &app.AuthToken{
			AccessToken:         tokenPair.AccessToken,
			RefreshToken:        tokenPair.RefreshToken,
			TokenType:           tokenPair.TokenType,
			AccessTokenExpires:  timeToProto(tokenPair.AccessTokenExpiresAt),
			RefreshTokenExpires: timeToProto(tokenPair.RefreshTokenExpiresAt),
		}
	}

	return response, nil
}

// validateRegisterRequest performs structural validation of a registration
// request (验证注册请求). It currently only rejects a nil request.
func (s *RegisterUserLogic) validateRegisterRequest(req *app.RegisterUserRequest) error {
	if req == nil {
		return errors.New("request is nil")
	}
	// TODO: field-level checks (username/email/mobile format, required fields
	// per verification type) are not implemented yet.
	return nil
}

// verifyRegistrationCode checks req's captcha against the code cached for the
// request's email or mobile number (验证注册验证码).
//
// The cache key is "email_verification_code:<email>" or
// "mobile_verification_code:<mobile>", chosen by req's verification type and
// built from the raw request value. It returns an error when:
//   - the verification type is neither EMAIL nor MOBILE;
//   - the cache lookup fails;
//   - no code is cached;
//   - the code does not match. In this case the attempt counter for the key is
//     incremented first, best-effort.
func (s *RegisterUserLogic) verifyRegistrationCode(ctx context.Context, req *app.RegisterUserRequest) error {
	// NOTE(review): the key uses the un-normalized email/mobile while the
	// stored user uses lowercased email / normalized mobile; confirm the code
	// sender builds keys the same way.
	var cacheKey string
	switch app.AccountType(req.GetVerificationType()) {
	case app.AccountType_EMAIL:
		cacheKey = fmt.Sprintf("email_verification_code:%s", req.GetEmail())
	case app.AccountType_MOBILE:
		cacheKey = fmt.Sprintf("mobile_verification_code:%s", req.GetMobile())
	default:
		return errors.New("invalid verification type")
	}

	codeData, err := s.cacheRepo.GetVerificationCode(ctx, cacheKey)
	if err != nil {
		// Returned (and reported by the caller) rather than logged here too.
		return fmt.Errorf("failed to get verification code: %w", err)
	}
	if codeData == nil {
		return errors.New("verification code expired or not found")
	}

	if codeData.Code != req.GetCaptcha() {
		// Best-effort: a failure to bump the counter must not mask the mismatch.
		if err := s.cacheRepo.IncrementVerificationAttempts(ctx, cacheKey); err != nil {
			s.Logger.Errorf("failed to increment verification attempts: %v", err)
		}
		// NOTE(review): attempts are counted but no maximum is checked in this
		// function; confirm the limit is enforced elsewhere.
		return errors.New("invalid verification code")
	}

	// NOTE(review): the code is not consumed on success. Deleting it here via
	// s.cacheRepo.DeleteVerificationCode was previously commented out, so the
	// code remains reusable until it expires. Consider deleting it once the user
	// has been created successfully.
	return nil
}

// timeToProto converts t to a protobuf Timestamp. It returns nil when t is
// outside the range a Timestamp can represent, as reported by CheckValid.
func timeToProto(t time.Time) *timestamppb.Timestamp {
	ts := timestamppb.New(t)
	if ts.CheckValid() != nil {
		return nil
	}
	return ts
}