590 lines
16 KiB
Go
590 lines
16 KiB
Go
package jwt_manager
|
||
|
||
import (
|
||
"context"
|
||
"crypto/md5"
|
||
"encoding/hex"
|
||
"errors"
|
||
"fmt"
|
||
"mingyang-admin-app-rpc/internal/config"
|
||
"time"
|
||
|
||
"github.com/golang-jwt/jwt/v5"
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
// TokenType 令牌类型
|
||
type TokenType string
|
||
|
||
const (
|
||
AccessToken TokenType = "access"
|
||
RefreshToken TokenType = "refresh"
|
||
)
|
||
|
||
// JWTConfig JWT 配置
|
||
|
||
// Claims JWT Claims 结构体
|
||
type Claims struct {
|
||
UserID uint64 `json:"user_id"`
|
||
Type TokenType `json:"type"`
|
||
jwt.RegisteredClaims
|
||
}
|
||
|
||
// TokenPair 令牌对
|
||
type TokenPair struct {
|
||
AccessToken string `json:"access_token"`
|
||
RefreshToken string `json:"refresh_token"`
|
||
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
|
||
RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"`
|
||
TokenType string `json:"token_type"`
|
||
}
|
||
|
||
// JWTManager JWT 管理器
|
||
type JWTManager struct {
|
||
accessSecret []byte
|
||
refreshSecret []byte
|
||
accessExpiry time.Duration
|
||
refreshExpiry time.Duration
|
||
issuer string
|
||
redis redis.UniversalClient
|
||
prefix string // Redis 键前缀
|
||
ctx context.Context
|
||
}
|
||
|
||
// NewJWTManager 创建新的 JWT 管理器
|
||
func NewJWTManager(config *config.JWTConfig, redisClient redis.UniversalClient) *JWTManager {
|
||
if config.AccessTokenSecret == "" || config.RefreshTokenSecret == "" {
|
||
panic("JWT 配置错误")
|
||
}
|
||
return &JWTManager{
|
||
accessSecret: []byte(config.AccessTokenSecret),
|
||
refreshSecret: []byte(config.RefreshTokenSecret),
|
||
accessExpiry: config.AccessTokenExpiry,
|
||
refreshExpiry: config.RefreshTokenExpiry,
|
||
issuer: config.Issuer,
|
||
redis: redisClient,
|
||
prefix: "jwt:",
|
||
ctx: context.Background(),
|
||
}
|
||
}
|
||
|
||
// WithContext 设置上下文
|
||
func (m *JWTManager) WithContext(ctx context.Context) *JWTManager {
|
||
m.ctx = ctx
|
||
return m
|
||
}
|
||
|
||
// ==================== Token 生成方法 ====================
|
||
|
||
// GenerateTokenPair 生成访问令牌和刷新令牌对
|
||
func (m *JWTManager) GenerateTokenPair(userID uint64) (*TokenPair, error) {
|
||
// 生成访问令牌
|
||
accessToken, accessClaims, err := m.generateToken(userID, AccessToken, m.accessSecret, m.accessExpiry)
|
||
fmt.Printf("accessToken: %v\n", accessToken)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("生成访问令牌失败: %w", err)
|
||
}
|
||
|
||
// 生成刷新令牌
|
||
refreshToken, refreshClaims, err := m.generateToken(userID, RefreshToken, m.refreshSecret, m.refreshExpiry)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("生成刷新令牌失败: %w", err)
|
||
}
|
||
|
||
// 存储令牌到 Redis(用于追踪和管理)
|
||
if m.redis != nil {
|
||
m.storeToken(userID, accessToken, accessClaims.ExpiresAt.Time)
|
||
m.storeToken(userID, refreshToken, refreshClaims.ExpiresAt.Time)
|
||
}
|
||
|
||
return &TokenPair{
|
||
AccessToken: accessToken,
|
||
RefreshToken: refreshToken,
|
||
AccessTokenExpiresAt: accessClaims.ExpiresAt.Time,
|
||
RefreshTokenExpiresAt: refreshClaims.ExpiresAt.Time,
|
||
TokenType: "Bearer",
|
||
}, nil
|
||
}
|
||
|
||
// GenerateAccessToken 生成访问令牌
|
||
func (m *JWTManager) GenerateAccessToken(userID uint64) (string, *Claims, error) {
|
||
token, claims, err := m.generateToken(userID, AccessToken, m.accessSecret, m.accessExpiry)
|
||
if err != nil {
|
||
return "", nil, err
|
||
}
|
||
|
||
// 存储到 Redis
|
||
if m.redis != nil {
|
||
m.storeToken(userID, token, claims.ExpiresAt.Time)
|
||
}
|
||
|
||
return token, claims, nil
|
||
}
|
||
|
||
// GenerateRefreshToken 生成刷新令牌
|
||
func (m *JWTManager) GenerateRefreshToken(userID uint64) (string, *Claims, error) {
|
||
token, claims, err := m.generateToken(userID, RefreshToken, m.refreshSecret, m.refreshExpiry)
|
||
if err != nil {
|
||
return "", nil, err
|
||
}
|
||
|
||
// 存储到 Redis
|
||
if m.redis != nil {
|
||
m.storeToken(userID, token, claims.ExpiresAt.Time)
|
||
}
|
||
|
||
return token, claims, nil
|
||
}
|
||
|
||
// generateToken 内部方法:生成令牌
|
||
func (m *JWTManager) generateToken(userID uint64, tokenType TokenType, secret []byte, expiry time.Duration) (string, *Claims, error) {
|
||
now := time.Now()
|
||
expireAt := now.Add(expiry)
|
||
|
||
claims := &Claims{
|
||
UserID: userID,
|
||
Type: tokenType,
|
||
RegisteredClaims: jwt.RegisteredClaims{
|
||
ExpiresAt: jwt.NewNumericDate(expireAt),
|
||
IssuedAt: jwt.NewNumericDate(now),
|
||
Issuer: m.issuer,
|
||
Subject: fmt.Sprintf("%d", userID),
|
||
},
|
||
}
|
||
|
||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||
tokenString, err := token.SignedString(secret)
|
||
if err != nil {
|
||
return "", nil, fmt.Errorf("签名令牌失败: %w", err)
|
||
}
|
||
|
||
return tokenString, claims, nil
|
||
}
|
||
|
||
// ==================== Token 验证方法 ====================
|
||
|
||
// VerifyAccessToken 验证访问令牌(包含黑名单检查)
|
||
func (m *JWTManager) VerifyAccessToken(tokenString string) (*Claims, error) {
|
||
return m.verifyToken(tokenString, AccessToken, m.accessSecret)
|
||
}
|
||
|
||
// VerifyRefreshToken 验证刷新令牌(包含黑名单检查)
|
||
func (m *JWTManager) VerifyRefreshToken(tokenString string) (*Claims, error) {
|
||
return m.verifyToken(tokenString, RefreshToken, m.refreshSecret)
|
||
}
|
||
|
||
// verifyToken 内部方法:验证令牌
|
||
func (m *JWTManager) verifyToken(tokenString string, expectedType TokenType, secret []byte) (*Claims, error) {
|
||
// 1. 检查令牌是否在黑名单中
|
||
if m.redis != nil {
|
||
blacklisted, err := m.isTokenBlacklisted(tokenString, expectedType)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("检查黑名单失败: %w", err)
|
||
}
|
||
if blacklisted {
|
||
return nil, errors.New("令牌已失效,请重新登录")
|
||
}
|
||
}
|
||
|
||
// 2. 验证 JWT
|
||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
return nil, fmt.Errorf("不支持的签名方法: %v", token.Header["alg"])
|
||
}
|
||
return secret, nil
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("解析令牌失败: %w", err)
|
||
}
|
||
|
||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||
// 3. 检查令牌类型
|
||
if claims.Type != expectedType {
|
||
return nil, errors.New("令牌类型不匹配")
|
||
}
|
||
|
||
// 4. 检查令牌是否过期
|
||
if claims.ExpiresAt.Time.Before(time.Now()) {
|
||
return nil, errors.New("令牌已过期")
|
||
}
|
||
|
||
return claims, nil
|
||
}
|
||
|
||
return nil, errors.New("无效的令牌")
|
||
}
|
||
|
||
// ==================== 黑名单管理 ====================
|
||
|
||
// BlacklistToken 将令牌加入黑名单
|
||
func (m *JWTManager) BlacklistToken(tokenString string) error {
|
||
if m.redis == nil {
|
||
return errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
// 尝试解析令牌
|
||
claims, err := m.parseTokenWithoutBlacklistCheck(tokenString)
|
||
if err != nil {
|
||
return fmt.Errorf("无法识别令牌: %w", err)
|
||
}
|
||
|
||
return m.addToBlacklist(tokenString, claims.Type, claims.ExpiresAt.Time)
|
||
}
|
||
|
||
// BlacklistAccessToken 将访问令牌加入黑名单
|
||
func (m *JWTManager) BlacklistAccessToken(tokenString string) error {
|
||
if m.redis == nil {
|
||
return errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
claims, err := m.VerifyAccessToken(tokenString)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的访问令牌: %w", err)
|
||
}
|
||
|
||
return m.addToBlacklist(tokenString, AccessToken, claims.ExpiresAt.Time)
|
||
}
|
||
|
||
// BlacklistRefreshToken 将刷新令牌加入黑名单
|
||
func (m *JWTManager) BlacklistRefreshToken(tokenString string) error {
|
||
if m.redis == nil {
|
||
return errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
claims, err := m.VerifyRefreshToken(tokenString)
|
||
if err != nil {
|
||
return fmt.Errorf("无效的刷新令牌: %w", err)
|
||
}
|
||
|
||
return m.addToBlacklist(tokenString, RefreshToken, claims.ExpiresAt.Time)
|
||
}
|
||
|
||
// addToBlacklist 内部方法:将令牌加入黑名单
|
||
func (m *JWTManager) addToBlacklist(tokenString string, tokenType TokenType, expireTime time.Time) error {
|
||
remaining := time.Until(expireTime)
|
||
if remaining <= 0 {
|
||
// 令牌已过期,不需要加入黑名单
|
||
return nil
|
||
}
|
||
|
||
tokenHash := m.hashToken(tokenString)
|
||
blacklistKey := m.getBlacklistKey(tokenHash, tokenType)
|
||
|
||
// 设置黑名单,过期时间等于令牌剩余有效期
|
||
err := m.redis.Set(m.ctx, blacklistKey, "1", remaining).Err()
|
||
if err != nil {
|
||
return fmt.Errorf("设置黑名单失败: %w", err)
|
||
}
|
||
|
||
// 记录用户与令牌的关联(用于按用户清理)
|
||
if claims, err := m.parseTokenWithoutBlacklistCheck(tokenString); err == nil {
|
||
userTokenKey := m.getUserTokenKey(claims.UserID, tokenHash)
|
||
m.redis.Set(m.ctx, userTokenKey, tokenString, remaining)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// IsTokenBlacklisted 检查令牌是否在黑名单中
|
||
func (m *JWTManager) IsTokenBlacklisted(tokenString string) (bool, error) {
|
||
if m.redis == nil {
|
||
return false, errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
// 检查所有类型的黑名单
|
||
tokenHash := m.hashToken(tokenString)
|
||
|
||
// 检查访问令牌黑名单
|
||
accessBlacklistKey := m.getBlacklistKey(tokenHash, AccessToken)
|
||
exists, err := m.redis.Exists(m.ctx, accessBlacklistKey).Result()
|
||
if err != nil {
|
||
return false, fmt.Errorf("检查访问令牌黑名单失败: %w", err)
|
||
}
|
||
if exists > 0 {
|
||
return true, nil
|
||
}
|
||
|
||
// 检查刷新令牌黑名单
|
||
refreshBlacklistKey := m.getBlacklistKey(tokenHash, RefreshToken)
|
||
exists, err = m.redis.Exists(m.ctx, refreshBlacklistKey).Result()
|
||
if err != nil {
|
||
return false, fmt.Errorf("检查刷新令牌黑名单失败: %w", err)
|
||
}
|
||
|
||
return exists > 0, nil
|
||
}
|
||
|
||
// isTokenBlacklisted 内部方法:检查令牌是否在黑名单中
|
||
func (m *JWTManager) isTokenBlacklisted(tokenString string, tokenType TokenType) (bool, error) {
|
||
if m.redis == nil {
|
||
return false, nil
|
||
}
|
||
|
||
tokenHash := m.hashToken(tokenString)
|
||
blacklistKey := m.getBlacklistKey(tokenHash, tokenType)
|
||
|
||
exists, err := m.redis.Exists(m.ctx, blacklistKey).Result()
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
|
||
return exists > 0, nil
|
||
}
|
||
|
||
// RemoveFromBlacklist 从黑名单中移除令牌
|
||
func (m *JWTManager) RemoveFromBlacklist(tokenString string) error {
|
||
if m.redis == nil {
|
||
return errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
tokenHash := m.hashToken(tokenString)
|
||
|
||
// 尝试移除两种类型的黑名单
|
||
accessBlacklistKey := m.getBlacklistKey(tokenHash, AccessToken)
|
||
refreshBlacklistKey := m.getBlacklistKey(tokenHash, RefreshToken)
|
||
|
||
err := m.redis.Del(m.ctx, accessBlacklistKey, refreshBlacklistKey).Err()
|
||
return err
|
||
}
|
||
|
||
// RevokeUserTokens 撤销用户所有令牌
|
||
func (m *JWTManager) RevokeUserTokens(userID uint64) error {
|
||
if m.redis == nil {
|
||
return errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
// 获取用户的所有令牌
|
||
tokens, err := m.GetUserTokens(userID)
|
||
if err != nil {
|
||
return fmt.Errorf("获取用户令牌失败: %w", err)
|
||
}
|
||
|
||
// 将所有令牌加入黑名单
|
||
for _, token := range tokens {
|
||
m.BlacklistToken(token)
|
||
}
|
||
|
||
// 清理用户令牌集合
|
||
userTokenSetKey := m.getUserTokenSetKey(userID)
|
||
err = m.redis.Del(m.ctx, userTokenSetKey).Err()
|
||
if err != nil {
|
||
return fmt.Errorf("清理用户令牌集合失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ==================== 用户令牌管理 ====================
|
||
|
||
// storeToken 存储用户令牌
|
||
func (m *JWTManager) storeToken(userID uint64, tokenString string, expireTime time.Time) {
|
||
if m.redis == nil {
|
||
return
|
||
}
|
||
|
||
// 将令牌添加到用户的令牌集合
|
||
userTokenSetKey := m.getUserTokenSetKey(userID)
|
||
err := m.redis.SAdd(m.ctx, userTokenSetKey, tokenString).Err()
|
||
if err != nil {
|
||
// 记录错误但继续执行
|
||
fmt.Printf("存储用户令牌失败: %v\n", err)
|
||
return
|
||
}
|
||
|
||
// 设置集合的过期时间(比令牌晚 1 小时)
|
||
expireAt := expireTime.Add(time.Hour)
|
||
err = m.redis.ExpireAt(m.ctx, userTokenSetKey, expireAt).Err()
|
||
if err != nil {
|
||
fmt.Printf("设置集合过期时间失败: %v\n", err)
|
||
}
|
||
}
|
||
|
||
// GetUserTokens 获取用户的所有令牌
|
||
func (m *JWTManager) GetUserTokens(userID uint64) ([]string, error) {
|
||
if m.redis == nil {
|
||
return nil, errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
userTokenSetKey := m.getUserTokenSetKey(userID)
|
||
return m.redis.SMembers(m.ctx, userTokenSetKey).Result()
|
||
}
|
||
|
||
// RemoveUserToken 移除用户的特定令牌
|
||
func (m *JWTManager) RemoveUserToken(userID uint64, tokenString string) error {
|
||
if m.redis == nil {
|
||
return errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
userTokenSetKey := m.getUserTokenSetKey(userID)
|
||
err := m.redis.SRem(m.ctx, userTokenSetKey, tokenString).Err()
|
||
return err
|
||
}
|
||
|
||
// ==================== 令牌刷新 ====================
|
||
|
||
// RefreshTokens 使用刷新令牌获取新的令牌对
|
||
func (m *JWTManager) RefreshTokens(refreshToken string) (*TokenPair, error) {
|
||
// 验证刷新令牌
|
||
claims, err := m.VerifyRefreshToken(refreshToken)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("刷新令牌验证失败: %w", err)
|
||
}
|
||
|
||
// 将旧的刷新令牌加入黑名单
|
||
if err := m.BlacklistRefreshToken(refreshToken); err != nil {
|
||
// 记录日志但继续执行
|
||
fmt.Printf("警告:将旧刷新令牌加入黑名单失败: %v\n", err)
|
||
}
|
||
|
||
// 生成新的令牌对
|
||
return m.GenerateTokenPair(claims.UserID)
|
||
}
|
||
|
||
// ==================== 辅助方法 ====================
|
||
|
||
// parseTokenWithoutBlacklistCheck 解析令牌但不检查黑名单
|
||
func (m *JWTManager) parseTokenWithoutBlacklistCheck(tokenString string) (*Claims, error) {
|
||
// 尝试作为访问令牌解析
|
||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||
return m.accessSecret, nil
|
||
})
|
||
|
||
if err == nil && token.Valid {
|
||
if claims, ok := token.Claims.(*Claims); ok {
|
||
return claims, nil
|
||
}
|
||
}
|
||
|
||
// 尝试作为刷新令牌解析
|
||
token, err = jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||
return m.refreshSecret, nil
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||
return claims, nil
|
||
}
|
||
|
||
return nil, errors.New("无法解析令牌")
|
||
}
|
||
|
||
// hashToken 生成令牌的哈希值
|
||
func (m *JWTManager) hashToken(token string) string {
|
||
hash := md5.Sum([]byte(token))
|
||
return hex.EncodeToString(hash[:])
|
||
}
|
||
|
||
// getBlacklistKey 获取黑名单键名
|
||
func (m *JWTManager) getBlacklistKey(tokenHash string, tokenType TokenType) string {
|
||
return fmt.Sprintf("%sblacklist:%s:%s", m.prefix, tokenType, tokenHash)
|
||
}
|
||
|
||
// getUserTokenSetKey 获取用户令牌集合键名
|
||
func (m *JWTManager) getUserTokenSetKey(userID uint64) string {
|
||
return fmt.Sprintf("%suser:tokens:%d", m.prefix, userID)
|
||
}
|
||
|
||
// getUserTokenKey 获取用户令牌键名
|
||
func (m *JWTManager) getUserTokenKey(userID uint64, tokenHash string) string {
|
||
return fmt.Sprintf("%suser:%d:token:%s", m.prefix, userID, tokenHash)
|
||
}
|
||
|
||
// ==================== 工具方法 ====================
|
||
|
||
// GetTokenInfo 获取令牌信息
|
||
func (m *JWTManager) GetTokenInfo(tokenString string) (map[string]interface{}, error) {
|
||
claims, err := m.parseTokenWithoutBlacklistCheck(tokenString)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
info := map[string]interface{}{
|
||
"user_id": claims.UserID,
|
||
"token_type": claims.Type,
|
||
"issued_at": claims.IssuedAt.Time.Format(time.RFC3339),
|
||
"expires_at": claims.ExpiresAt.Time.Format(time.RFC3339),
|
||
"issuer": claims.Issuer,
|
||
"subject": claims.Subject,
|
||
"remaining": time.Until(claims.ExpiresAt.Time).String(),
|
||
}
|
||
|
||
// 检查是否在黑名单中
|
||
if m.redis != nil {
|
||
blacklisted, _ := m.IsTokenBlacklisted(tokenString)
|
||
info["blacklisted"] = blacklisted
|
||
}
|
||
|
||
return info, nil
|
||
}
|
||
|
||
// ValidateAndExtract 验证令牌并提取用户ID
|
||
func (m *JWTManager) ValidateAndExtract(tokenString string) (uint64, error) {
|
||
claims, err := m.VerifyAccessToken(tokenString)
|
||
if err != nil {
|
||
return 0, err
|
||
}
|
||
return claims.UserID, nil
|
||
}
|
||
|
||
// Cleanup 清理过期的令牌数据
|
||
func (m *JWTManager) Cleanup() error {
|
||
if m.redis == nil {
|
||
return nil
|
||
}
|
||
|
||
// Redis 会自动清理过期的键
|
||
return nil
|
||
}
|
||
|
||
// GetBlacklistStats 获取黑名单统计信息
|
||
func (m *JWTManager) GetBlacklistStats() (map[string]interface{}, error) {
|
||
if m.redis == nil {
|
||
return nil, errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
stats := make(map[string]interface{})
|
||
|
||
// 获取访问令牌黑名单数量
|
||
accessPattern := fmt.Sprintf("%sblacklist:access:*", m.prefix)
|
||
accessKeys, err := m.redis.Keys(m.ctx, accessPattern).Result()
|
||
if err == nil {
|
||
stats["access_token_blacklist_count"] = len(accessKeys)
|
||
} else {
|
||
stats["access_token_blacklist_count"] = 0
|
||
}
|
||
|
||
// 获取刷新令牌黑名单数量
|
||
refreshPattern := fmt.Sprintf("%sblacklist:refresh:*", m.prefix)
|
||
refreshKeys, err := m.redis.Keys(m.ctx, refreshPattern).Result()
|
||
if err == nil {
|
||
stats["refresh_token_blacklist_count"] = len(refreshKeys)
|
||
} else {
|
||
stats["refresh_token_blacklist_count"] = 0
|
||
}
|
||
|
||
return stats, nil
|
||
}
|
||
|
||
// Ping 检查 Redis 连接
|
||
func (m *JWTManager) Ping() error {
|
||
if m.redis == nil {
|
||
return errors.New("redis 客户端未初始化")
|
||
}
|
||
|
||
_, err := m.redis.Ping(m.ctx).Result()
|
||
return err
|
||
}
|
||
|
||
// Close 关闭 Redis 连接
|
||
func (m *JWTManager) Close() error {
|
||
if m.redis == nil {
|
||
return nil
|
||
}
|
||
|
||
return m.redis.Close()
|
||
}
|