package jwt_manager import ( "context" "crypto/md5" "encoding/hex" "errors" "fmt" "mingyang-admin-app-rpc/internal/config" "time" "github.com/golang-jwt/jwt/v5" "github.com/redis/go-redis/v9" ) // TokenType 令牌类型 type TokenType string const ( AccessToken TokenType = "access" RefreshToken TokenType = "refresh" ) // JWTConfig JWT 配置 // Claims JWT Claims 结构体 type Claims struct { UserID uint64 `json:"user_id"` Type TokenType `json:"type"` jwt.RegisteredClaims } // TokenPair 令牌对 type TokenPair struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` AccessTokenExpiresAt time.Time `json:"access_token_expires_at"` RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"` TokenType string `json:"token_type"` } // JWTManager JWT 管理器 type JWTManager struct { accessSecret []byte refreshSecret []byte accessExpiry time.Duration refreshExpiry time.Duration issuer string redis redis.UniversalClient prefix string // Redis 键前缀 ctx context.Context } // NewJWTManager 创建新的 JWT 管理器 func NewJWTManager(config *config.JWTConfig, redisClient redis.UniversalClient) *JWTManager { if config.AccessTokenSecret == "" || config.RefreshTokenSecret == "" { panic("JWT 配置错误") } return &JWTManager{ accessSecret: []byte(config.AccessTokenSecret), refreshSecret: []byte(config.RefreshTokenSecret), accessExpiry: config.AccessTokenExpiry, refreshExpiry: config.RefreshTokenExpiry, issuer: config.Issuer, redis: redisClient, prefix: "jwt:", ctx: context.Background(), } } // WithContext 设置上下文 func (m *JWTManager) WithContext(ctx context.Context) *JWTManager { m.ctx = ctx return m } // ==================== Token 生成方法 ==================== // GenerateTokenPair 生成访问令牌和刷新令牌对 func (m *JWTManager) GenerateTokenPair(userID uint64) (*TokenPair, error) { // 生成访问令牌 accessToken, accessClaims, err := m.generateToken(userID, AccessToken, m.accessSecret, m.accessExpiry) fmt.Printf("accessToken: %v\n", accessToken) if err != nil { return nil, fmt.Errorf("生成访问令牌失败: %w", err) } // 生成刷新令牌 refreshToken, refreshClaims, err := m.generateToken(userID, RefreshToken, m.refreshSecret, m.refreshExpiry) if err != nil { return nil, fmt.Errorf("生成刷新令牌失败: %w", err) } // 存储令牌到 Redis(用于追踪和管理) if m.redis != nil { m.storeToken(userID, accessToken, accessClaims.ExpiresAt.Time) m.storeToken(userID, refreshToken, refreshClaims.ExpiresAt.Time) } return &TokenPair{ AccessToken: accessToken, RefreshToken: refreshToken, AccessTokenExpiresAt: accessClaims.ExpiresAt.Time, RefreshTokenExpiresAt: refreshClaims.ExpiresAt.Time, TokenType: "Bearer", }, nil } // GenerateAccessToken 生成访问令牌 func (m *JWTManager) GenerateAccessToken(userID uint64) (string, *Claims, error) { token, claims, err := m.generateToken(userID, AccessToken, m.accessSecret, m.accessExpiry) if err != nil { return "", nil, err } // 存储到 Redis if m.redis != nil { m.storeToken(userID, token, claims.ExpiresAt.Time) } return token, claims, nil } // GenerateRefreshToken 生成刷新令牌 func (m *JWTManager) GenerateRefreshToken(userID uint64) (string, *Claims, error) { token, claims, err := m.generateToken(userID, RefreshToken, m.refreshSecret, m.refreshExpiry) if err != nil { return "", nil, err } // 存储到 Redis if m.redis != nil { m.storeToken(userID, token, claims.ExpiresAt.Time) } return token, claims, nil } // generateToken 内部方法:生成令牌 func (m *JWTManager) generateToken(userID uint64, tokenType TokenType, secret []byte, expiry time.Duration) (string, *Claims, error) { now := time.Now() expireAt := now.Add(expiry) claims := &Claims{ UserID: userID, Type: tokenType, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(expireAt), IssuedAt: jwt.NewNumericDate(now), Issuer: m.issuer, Subject: fmt.Sprintf("%d", userID), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString(secret) if err != nil { return "", nil, fmt.Errorf("签名令牌失败: %w", err) } return tokenString, claims, nil } // ==================== Token 验证方法 ==================== // VerifyAccessToken 验证访问令牌(包含黑名单检查) func (m *JWTManager) VerifyAccessToken(tokenString string) (*Claims, error) { return m.verifyToken(tokenString, AccessToken, m.accessSecret) } // VerifyRefreshToken 验证刷新令牌(包含黑名单检查) func (m *JWTManager) VerifyRefreshToken(tokenString string) (*Claims, error) { return m.verifyToken(tokenString, RefreshToken, m.refreshSecret) } // verifyToken 内部方法:验证令牌 func (m *JWTManager) verifyToken(tokenString string, expectedType TokenType, secret []byte) (*Claims, error) { // 1. 检查令牌是否在黑名单中 if m.redis != nil { blacklisted, err := m.isTokenBlacklisted(tokenString, expectedType) if err != nil { return nil, fmt.Errorf("检查黑名单失败: %w", err) } if blacklisted { return nil, errors.New("令牌已失效,请重新登录") } } // 2. 验证 JWT token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("不支持的签名方法: %v", token.Header["alg"]) } return secret, nil }) if err != nil { return nil, fmt.Errorf("解析令牌失败: %w", err) } if claims, ok := token.Claims.(*Claims); ok && token.Valid { // 3. 检查令牌类型 if claims.Type != expectedType { return nil, errors.New("令牌类型不匹配") } // 4. 检查令牌是否过期 if claims.ExpiresAt.Time.Before(time.Now()) { return nil, errors.New("令牌已过期") } return claims, nil } return nil, errors.New("无效的令牌") } // ==================== 黑名单管理 ==================== // BlacklistToken 将令牌加入黑名单 func (m *JWTManager) BlacklistToken(tokenString string) error { if m.redis == nil { return errors.New("redis 客户端未初始化") } // 尝试解析令牌 claims, err := m.parseTokenWithoutBlacklistCheck(tokenString) if err != nil { return fmt.Errorf("无法识别令牌: %w", err) } return m.addToBlacklist(tokenString, claims.Type, claims.ExpiresAt.Time) } // BlacklistAccessToken 将访问令牌加入黑名单 func (m *JWTManager) BlacklistAccessToken(tokenString string) error { if m.redis == nil { return errors.New("redis 客户端未初始化") } claims, err := m.VerifyAccessToken(tokenString) if err != nil { return fmt.Errorf("无效的访问令牌: %w", err) } return m.addToBlacklist(tokenString, AccessToken, claims.ExpiresAt.Time) } // BlacklistRefreshToken 将刷新令牌加入黑名单 func (m *JWTManager) BlacklistRefreshToken(tokenString string) error { if m.redis == nil { return errors.New("redis 客户端未初始化") } claims, err := m.VerifyRefreshToken(tokenString) if err != nil { return fmt.Errorf("无效的刷新令牌: %w", err) } return m.addToBlacklist(tokenString, RefreshToken, claims.ExpiresAt.Time) } // addToBlacklist 内部方法:将令牌加入黑名单 func (m *JWTManager) addToBlacklist(tokenString string, tokenType TokenType, expireTime time.Time) error { remaining := time.Until(expireTime) if remaining <= 0 { // 令牌已过期,不需要加入黑名单 return nil } tokenHash := m.hashToken(tokenString) blacklistKey := m.getBlacklistKey(tokenHash, tokenType) // 设置黑名单,过期时间等于令牌剩余有效期 err := m.redis.Set(m.ctx, blacklistKey, "1", remaining).Err() if err != nil { return fmt.Errorf("设置黑名单失败: %w", err) } // 记录用户与令牌的关联(用于按用户清理) if claims, err := m.parseTokenWithoutBlacklistCheck(tokenString); err == nil { userTokenKey := m.getUserTokenKey(claims.UserID, tokenHash) m.redis.Set(m.ctx, userTokenKey, tokenString, remaining) } return nil } // IsTokenBlacklisted 检查令牌是否在黑名单中 func (m *JWTManager) IsTokenBlacklisted(tokenString string) (bool, error) { if m.redis == nil { return false, errors.New("redis 客户端未初始化") } // 检查所有类型的黑名单 tokenHash := m.hashToken(tokenString) // 检查访问令牌黑名单 accessBlacklistKey := m.getBlacklistKey(tokenHash, AccessToken) exists, err := m.redis.Exists(m.ctx, accessBlacklistKey).Result() if err != nil { return false, fmt.Errorf("检查访问令牌黑名单失败: %w", err) } if exists > 0 { return true, nil } // 检查刷新令牌黑名单 refreshBlacklistKey := m.getBlacklistKey(tokenHash, RefreshToken) exists, err = m.redis.Exists(m.ctx, refreshBlacklistKey).Result() if err != nil { return false, fmt.Errorf("检查刷新令牌黑名单失败: %w", err) } return exists > 0, nil } // isTokenBlacklisted 内部方法:检查令牌是否在黑名单中 func (m *JWTManager) isTokenBlacklisted(tokenString string, tokenType TokenType) (bool, error) { if m.redis == nil { return false, nil } tokenHash := m.hashToken(tokenString) blacklistKey := m.getBlacklistKey(tokenHash, tokenType) exists, err := m.redis.Exists(m.ctx, blacklistKey).Result() if err != nil { return false, err } return exists > 0, nil } // RemoveFromBlacklist 从黑名单中移除令牌 func (m *JWTManager) RemoveFromBlacklist(tokenString string) error { if m.redis == nil { return errors.New("redis 客户端未初始化") } tokenHash := m.hashToken(tokenString) // 尝试移除两种类型的黑名单 accessBlacklistKey := m.getBlacklistKey(tokenHash, AccessToken) refreshBlacklistKey := m.getBlacklistKey(tokenHash, RefreshToken) err := m.redis.Del(m.ctx, accessBlacklistKey, refreshBlacklistKey).Err() return err } // RevokeUserTokens 撤销用户所有令牌 func (m *JWTManager) RevokeUserTokens(userID uint64) error { if m.redis == nil { return errors.New("redis 客户端未初始化") } // 获取用户的所有令牌 tokens, err := m.GetUserTokens(userID) if err != nil { return fmt.Errorf("获取用户令牌失败: %w", err) } // 将所有令牌加入黑名单 for _, token := range tokens { m.BlacklistToken(token) } // 清理用户令牌集合 userTokenSetKey := m.getUserTokenSetKey(userID) err = m.redis.Del(m.ctx, userTokenSetKey).Err() if err != nil { return fmt.Errorf("清理用户令牌集合失败: %w", err) } return nil } // ==================== 用户令牌管理 ==================== // storeToken 存储用户令牌 func (m *JWTManager) storeToken(userID uint64, tokenString string, expireTime time.Time) { if m.redis == nil { return } // 将令牌添加到用户的令牌集合 userTokenSetKey := m.getUserTokenSetKey(userID) err := m.redis.SAdd(m.ctx, userTokenSetKey, tokenString).Err() if err != nil { // 记录错误但继续执行 fmt.Printf("存储用户令牌失败: %v\n", err) return } // 设置集合的过期时间(比令牌晚 1 小时) expireAt := expireTime.Add(time.Hour) err = m.redis.ExpireAt(m.ctx, userTokenSetKey, expireAt).Err() if err != nil { fmt.Printf("设置集合过期时间失败: %v\n", err) } } // GetUserTokens 获取用户的所有令牌 func (m *JWTManager) GetUserTokens(userID uint64) ([]string, error) { if m.redis == nil { return nil, errors.New("redis 客户端未初始化") } userTokenSetKey := m.getUserTokenSetKey(userID) return m.redis.SMembers(m.ctx, userTokenSetKey).Result() } // RemoveUserToken 移除用户的特定令牌 func (m *JWTManager) RemoveUserToken(userID uint64, tokenString string) error { if m.redis == nil { return errors.New("redis 客户端未初始化") } userTokenSetKey := m.getUserTokenSetKey(userID) err := m.redis.SRem(m.ctx, userTokenSetKey, tokenString).Err() return err } // ==================== 令牌刷新 ==================== // RefreshTokens 使用刷新令牌获取新的令牌对 func (m *JWTManager) RefreshTokens(refreshToken string) (*TokenPair, error) { // 验证刷新令牌 claims, err := m.VerifyRefreshToken(refreshToken) if err != nil { return nil, fmt.Errorf("刷新令牌验证失败: %w", err) } // 将旧的刷新令牌加入黑名单 if err := m.BlacklistRefreshToken(refreshToken); err != nil { // 记录日志但继续执行 fmt.Printf("警告:将旧刷新令牌加入黑名单失败: %v\n", err) } // 生成新的令牌对 return m.GenerateTokenPair(claims.UserID) } // ==================== 辅助方法 ==================== // parseTokenWithoutBlacklistCheck 解析令牌但不检查黑名单 func (m *JWTManager) parseTokenWithoutBlacklistCheck(tokenString string) (*Claims, error) { // 尝试作为访问令牌解析 token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { return m.accessSecret, nil }) if err == nil && token.Valid { if claims, ok := token.Claims.(*Claims); ok { return claims, nil } } // 尝试作为刷新令牌解析 token, err = jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { return m.refreshSecret, nil }) if err != nil { return nil, err } if claims, ok := token.Claims.(*Claims); ok && token.Valid { return claims, nil } return nil, errors.New("无法解析令牌") } // hashToken 生成令牌的哈希值 func (m *JWTManager) hashToken(token string) string { hash := md5.Sum([]byte(token)) return hex.EncodeToString(hash[:]) } // getBlacklistKey 获取黑名单键名 func (m *JWTManager) getBlacklistKey(tokenHash string, tokenType TokenType) string { return fmt.Sprintf("%sblacklist:%s:%s", m.prefix, tokenType, tokenHash) } // getUserTokenSetKey 获取用户令牌集合键名 func (m *JWTManager) getUserTokenSetKey(userID uint64) string { return fmt.Sprintf("%suser:tokens:%d", m.prefix, userID) } // getUserTokenKey 获取用户令牌键名 func (m *JWTManager) getUserTokenKey(userID uint64, tokenHash string) string { return fmt.Sprintf("%suser:%d:token:%s", m.prefix, userID, tokenHash) } // ==================== 工具方法 ==================== // GetTokenInfo 获取令牌信息 func (m *JWTManager) GetTokenInfo(tokenString string) (map[string]interface{}, error) { claims, err := m.parseTokenWithoutBlacklistCheck(tokenString) if err != nil { return nil, err } info := map[string]interface{}{ "user_id": claims.UserID, "token_type": claims.Type, "issued_at": claims.IssuedAt.Time.Format(time.RFC3339), "expires_at": claims.ExpiresAt.Time.Format(time.RFC3339), "issuer": claims.Issuer, "subject": claims.Subject, "remaining": time.Until(claims.ExpiresAt.Time).String(), } // 检查是否在黑名单中 if m.redis != nil { blacklisted, _ := m.IsTokenBlacklisted(tokenString) info["blacklisted"] = blacklisted } return info, nil } // ValidateAndExtract 验证令牌并提取用户ID func (m *JWTManager) ValidateAndExtract(tokenString string) (uint64, error) { claims, err := m.VerifyAccessToken(tokenString) if err != nil { return 0, err } return claims.UserID, nil } // Cleanup 清理过期的令牌数据 func (m *JWTManager) Cleanup() error { if m.redis == nil { return nil } // Redis 会自动清理过期的键 return nil } // GetBlacklistStats 获取黑名单统计信息 func (m *JWTManager) GetBlacklistStats() (map[string]interface{}, error) { if m.redis == nil { return nil, errors.New("redis 客户端未初始化") } stats := make(map[string]interface{}) // 获取访问令牌黑名单数量 accessPattern := fmt.Sprintf("%sblacklist:access:*", m.prefix) accessKeys, err := m.redis.Keys(m.ctx, accessPattern).Result() if err == nil { stats["access_token_blacklist_count"] = len(accessKeys) } else { stats["access_token_blacklist_count"] = 0 } // 获取刷新令牌黑名单数量 refreshPattern := fmt.Sprintf("%sblacklist:refresh:*", m.prefix) refreshKeys, err := m.redis.Keys(m.ctx, refreshPattern).Result() if err == nil { stats["refresh_token_blacklist_count"] = len(refreshKeys) } else { stats["refresh_token_blacklist_count"] = 0 } return stats, nil } // Ping 检查 Redis 连接 func (m *JWTManager) Ping() error { if m.redis == nil { return errors.New("redis 客户端未初始化") } _, err := m.redis.Ping(m.ctx).Result() return err } // Close 关闭 Redis 连接 func (m *JWTManager) Close() error { if m.redis == nil { return nil } return m.redis.Close() }