mingyang-admin-iot-app/rpc/internal/jwt_manager/jwt_manager.go

590 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package jwt_manager
import (
"context"
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"mingyang-admin-app-rpc/internal/config"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/redis/go-redis/v9"
)
// TokenType 令牌类型
type TokenType string
const (
AccessToken TokenType = "access"
RefreshToken TokenType = "refresh"
)
// JWTConfig JWT 配置
// Claims JWT Claims 结构体
type Claims struct {
UserID uint64 `json:"user_id"`
Type TokenType `json:"type"`
jwt.RegisteredClaims
}
// TokenPair 令牌对
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
AccessTokenExpiresAt time.Time `json:"access_token_expires_at"`
RefreshTokenExpiresAt time.Time `json:"refresh_token_expires_at"`
TokenType string `json:"token_type"`
}
// JWTManager JWT 管理器
type JWTManager struct {
accessSecret []byte
refreshSecret []byte
accessExpiry time.Duration
refreshExpiry time.Duration
issuer string
redis redis.UniversalClient
prefix string // Redis 键前缀
ctx context.Context
}
// NewJWTManager 创建新的 JWT 管理器
func NewJWTManager(config *config.JWTConfig, redisClient redis.UniversalClient) *JWTManager {
if config.AccessTokenSecret == "" || config.RefreshTokenSecret == "" {
panic("JWT 配置错误")
}
return &JWTManager{
accessSecret: []byte(config.AccessTokenSecret),
refreshSecret: []byte(config.RefreshTokenSecret),
accessExpiry: config.AccessTokenExpiry,
refreshExpiry: config.RefreshTokenExpiry,
issuer: config.Issuer,
redis: redisClient,
prefix: "jwt:",
ctx: context.Background(),
}
}
// WithContext 设置上下文
func (m *JWTManager) WithContext(ctx context.Context) *JWTManager {
m.ctx = ctx
return m
}
// ==================== Token 生成方法 ====================
// GenerateTokenPair 生成访问令牌和刷新令牌对
func (m *JWTManager) GenerateTokenPair(userID uint64) (*TokenPair, error) {
// 生成访问令牌
accessToken, accessClaims, err := m.generateToken(userID, AccessToken, m.accessSecret, m.accessExpiry)
fmt.Printf("accessToken: %v\n", accessToken)
if err != nil {
return nil, fmt.Errorf("生成访问令牌失败: %w", err)
}
// 生成刷新令牌
refreshToken, refreshClaims, err := m.generateToken(userID, RefreshToken, m.refreshSecret, m.refreshExpiry)
if err != nil {
return nil, fmt.Errorf("生成刷新令牌失败: %w", err)
}
// 存储令牌到 Redis用于追踪和管理
if m.redis != nil {
m.storeToken(userID, accessToken, accessClaims.ExpiresAt.Time)
m.storeToken(userID, refreshToken, refreshClaims.ExpiresAt.Time)
}
return &TokenPair{
AccessToken: accessToken,
RefreshToken: refreshToken,
AccessTokenExpiresAt: accessClaims.ExpiresAt.Time,
RefreshTokenExpiresAt: refreshClaims.ExpiresAt.Time,
TokenType: "Bearer",
}, nil
}
// GenerateAccessToken 生成访问令牌
func (m *JWTManager) GenerateAccessToken(userID uint64) (string, *Claims, error) {
token, claims, err := m.generateToken(userID, AccessToken, m.accessSecret, m.accessExpiry)
if err != nil {
return "", nil, err
}
// 存储到 Redis
if m.redis != nil {
m.storeToken(userID, token, claims.ExpiresAt.Time)
}
return token, claims, nil
}
// GenerateRefreshToken 生成刷新令牌
func (m *JWTManager) GenerateRefreshToken(userID uint64) (string, *Claims, error) {
token, claims, err := m.generateToken(userID, RefreshToken, m.refreshSecret, m.refreshExpiry)
if err != nil {
return "", nil, err
}
// 存储到 Redis
if m.redis != nil {
m.storeToken(userID, token, claims.ExpiresAt.Time)
}
return token, claims, nil
}
// generateToken 内部方法:生成令牌
func (m *JWTManager) generateToken(userID uint64, tokenType TokenType, secret []byte, expiry time.Duration) (string, *Claims, error) {
now := time.Now()
expireAt := now.Add(expiry)
claims := &Claims{
UserID: userID,
Type: tokenType,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expireAt),
IssuedAt: jwt.NewNumericDate(now),
Issuer: m.issuer,
Subject: fmt.Sprintf("%d", userID),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(secret)
if err != nil {
return "", nil, fmt.Errorf("签名令牌失败: %w", err)
}
return tokenString, claims, nil
}
// ==================== Token 验证方法 ====================
// VerifyAccessToken 验证访问令牌(包含黑名单检查)
func (m *JWTManager) VerifyAccessToken(tokenString string) (*Claims, error) {
return m.verifyToken(tokenString, AccessToken, m.accessSecret)
}
// VerifyRefreshToken 验证刷新令牌(包含黑名单检查)
func (m *JWTManager) VerifyRefreshToken(tokenString string) (*Claims, error) {
return m.verifyToken(tokenString, RefreshToken, m.refreshSecret)
}
// verifyToken 内部方法:验证令牌
func (m *JWTManager) verifyToken(tokenString string, expectedType TokenType, secret []byte) (*Claims, error) {
// 1. 检查令牌是否在黑名单中
if m.redis != nil {
blacklisted, err := m.isTokenBlacklisted(tokenString, expectedType)
if err != nil {
return nil, fmt.Errorf("检查黑名单失败: %w", err)
}
if blacklisted {
return nil, errors.New("令牌已失效,请重新登录")
}
}
// 2. 验证 JWT
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("不支持的签名方法: %v", token.Header["alg"])
}
return secret, nil
})
if err != nil {
return nil, fmt.Errorf("解析令牌失败: %w", err)
}
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
// 3. 检查令牌类型
if claims.Type != expectedType {
return nil, errors.New("令牌类型不匹配")
}
// 4. 检查令牌是否过期
if claims.ExpiresAt.Time.Before(time.Now()) {
return nil, errors.New("令牌已过期")
}
return claims, nil
}
return nil, errors.New("无效的令牌")
}
// ==================== 黑名单管理 ====================
// BlacklistToken 将令牌加入黑名单
func (m *JWTManager) BlacklistToken(tokenString string) error {
if m.redis == nil {
return errors.New("redis 客户端未初始化")
}
// 尝试解析令牌
claims, err := m.parseTokenWithoutBlacklistCheck(tokenString)
if err != nil {
return fmt.Errorf("无法识别令牌: %w", err)
}
return m.addToBlacklist(tokenString, claims.Type, claims.ExpiresAt.Time)
}
// BlacklistAccessToken 将访问令牌加入黑名单
func (m *JWTManager) BlacklistAccessToken(tokenString string) error {
if m.redis == nil {
return errors.New("redis 客户端未初始化")
}
claims, err := m.VerifyAccessToken(tokenString)
if err != nil {
return fmt.Errorf("无效的访问令牌: %w", err)
}
return m.addToBlacklist(tokenString, AccessToken, claims.ExpiresAt.Time)
}
// BlacklistRefreshToken 将刷新令牌加入黑名单
func (m *JWTManager) BlacklistRefreshToken(tokenString string) error {
if m.redis == nil {
return errors.New("redis 客户端未初始化")
}
claims, err := m.VerifyRefreshToken(tokenString)
if err != nil {
return fmt.Errorf("无效的刷新令牌: %w", err)
}
return m.addToBlacklist(tokenString, RefreshToken, claims.ExpiresAt.Time)
}
// addToBlacklist 内部方法:将令牌加入黑名单
func (m *JWTManager) addToBlacklist(tokenString string, tokenType TokenType, expireTime time.Time) error {
remaining := time.Until(expireTime)
if remaining <= 0 {
// 令牌已过期,不需要加入黑名单
return nil
}
tokenHash := m.hashToken(tokenString)
blacklistKey := m.getBlacklistKey(tokenHash, tokenType)
// 设置黑名单,过期时间等于令牌剩余有效期
err := m.redis.Set(m.ctx, blacklistKey, "1", remaining).Err()
if err != nil {
return fmt.Errorf("设置黑名单失败: %w", err)
}
// 记录用户与令牌的关联(用于按用户清理)
if claims, err := m.parseTokenWithoutBlacklistCheck(tokenString); err == nil {
userTokenKey := m.getUserTokenKey(claims.UserID, tokenHash)
m.redis.Set(m.ctx, userTokenKey, tokenString, remaining)
}
return nil
}
// IsTokenBlacklisted 检查令牌是否在黑名单中
func (m *JWTManager) IsTokenBlacklisted(tokenString string) (bool, error) {
if m.redis == nil {
return false, errors.New("redis 客户端未初始化")
}
// 检查所有类型的黑名单
tokenHash := m.hashToken(tokenString)
// 检查访问令牌黑名单
accessBlacklistKey := m.getBlacklistKey(tokenHash, AccessToken)
exists, err := m.redis.Exists(m.ctx, accessBlacklistKey).Result()
if err != nil {
return false, fmt.Errorf("检查访问令牌黑名单失败: %w", err)
}
if exists > 0 {
return true, nil
}
// 检查刷新令牌黑名单
refreshBlacklistKey := m.getBlacklistKey(tokenHash, RefreshToken)
exists, err = m.redis.Exists(m.ctx, refreshBlacklistKey).Result()
if err != nil {
return false, fmt.Errorf("检查刷新令牌黑名单失败: %w", err)
}
return exists > 0, nil
}
// isTokenBlacklisted 内部方法:检查令牌是否在黑名单中
func (m *JWTManager) isTokenBlacklisted(tokenString string, tokenType TokenType) (bool, error) {
if m.redis == nil {
return false, nil
}
tokenHash := m.hashToken(tokenString)
blacklistKey := m.getBlacklistKey(tokenHash, tokenType)
exists, err := m.redis.Exists(m.ctx, blacklistKey).Result()
if err != nil {
return false, err
}
return exists > 0, nil
}
// RemoveFromBlacklist 从黑名单中移除令牌
func (m *JWTManager) RemoveFromBlacklist(tokenString string) error {
if m.redis == nil {
return errors.New("redis 客户端未初始化")
}
tokenHash := m.hashToken(tokenString)
// 尝试移除两种类型的黑名单
accessBlacklistKey := m.getBlacklistKey(tokenHash, AccessToken)
refreshBlacklistKey := m.getBlacklistKey(tokenHash, RefreshToken)
err := m.redis.Del(m.ctx, accessBlacklistKey, refreshBlacklistKey).Err()
return err
}
// RevokeUserTokens 撤销用户所有令牌
func (m *JWTManager) RevokeUserTokens(userID uint64) error {
if m.redis == nil {
return errors.New("redis 客户端未初始化")
}
// 获取用户的所有令牌
tokens, err := m.GetUserTokens(userID)
if err != nil {
return fmt.Errorf("获取用户令牌失败: %w", err)
}
// 将所有令牌加入黑名单
for _, token := range tokens {
m.BlacklistToken(token)
}
// 清理用户令牌集合
userTokenSetKey := m.getUserTokenSetKey(userID)
err = m.redis.Del(m.ctx, userTokenSetKey).Err()
if err != nil {
return fmt.Errorf("清理用户令牌集合失败: %w", err)
}
return nil
}
// ==================== 用户令牌管理 ====================
// storeToken 存储用户令牌
func (m *JWTManager) storeToken(userID uint64, tokenString string, expireTime time.Time) {
if m.redis == nil {
return
}
// 将令牌添加到用户的令牌集合
userTokenSetKey := m.getUserTokenSetKey(userID)
err := m.redis.SAdd(m.ctx, userTokenSetKey, tokenString).Err()
if err != nil {
// 记录错误但继续执行
fmt.Printf("存储用户令牌失败: %v\n", err)
return
}
// 设置集合的过期时间(比令牌晚 1 小时)
expireAt := expireTime.Add(time.Hour)
err = m.redis.ExpireAt(m.ctx, userTokenSetKey, expireAt).Err()
if err != nil {
fmt.Printf("设置集合过期时间失败: %v\n", err)
}
}
// GetUserTokens 获取用户的所有令牌
func (m *JWTManager) GetUserTokens(userID uint64) ([]string, error) {
if m.redis == nil {
return nil, errors.New("redis 客户端未初始化")
}
userTokenSetKey := m.getUserTokenSetKey(userID)
return m.redis.SMembers(m.ctx, userTokenSetKey).Result()
}
// RemoveUserToken 移除用户的特定令牌
func (m *JWTManager) RemoveUserToken(userID uint64, tokenString string) error {
if m.redis == nil {
return errors.New("redis 客户端未初始化")
}
userTokenSetKey := m.getUserTokenSetKey(userID)
err := m.redis.SRem(m.ctx, userTokenSetKey, tokenString).Err()
return err
}
// ==================== 令牌刷新 ====================
// RefreshTokens 使用刷新令牌获取新的令牌对
func (m *JWTManager) RefreshTokens(refreshToken string) (*TokenPair, error) {
// 验证刷新令牌
claims, err := m.VerifyRefreshToken(refreshToken)
if err != nil {
return nil, fmt.Errorf("刷新令牌验证失败: %w", err)
}
// 将旧的刷新令牌加入黑名单
if err := m.BlacklistRefreshToken(refreshToken); err != nil {
// 记录日志但继续执行
fmt.Printf("警告:将旧刷新令牌加入黑名单失败: %v\n", err)
}
// 生成新的令牌对
return m.GenerateTokenPair(claims.UserID)
}
// ==================== 辅助方法 ====================
// parseTokenWithoutBlacklistCheck 解析令牌但不检查黑名单
func (m *JWTManager) parseTokenWithoutBlacklistCheck(tokenString string) (*Claims, error) {
// 尝试作为访问令牌解析
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return m.accessSecret, nil
})
if err == nil && token.Valid {
if claims, ok := token.Claims.(*Claims); ok {
return claims, nil
}
}
// 尝试作为刷新令牌解析
token, err = jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
return m.refreshSecret, nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
return claims, nil
}
return nil, errors.New("无法解析令牌")
}
// hashToken 生成令牌的哈希值
func (m *JWTManager) hashToken(token string) string {
hash := md5.Sum([]byte(token))
return hex.EncodeToString(hash[:])
}
// getBlacklistKey 获取黑名单键名
func (m *JWTManager) getBlacklistKey(tokenHash string, tokenType TokenType) string {
return fmt.Sprintf("%sblacklist:%s:%s", m.prefix, tokenType, tokenHash)
}
// getUserTokenSetKey 获取用户令牌集合键名
func (m *JWTManager) getUserTokenSetKey(userID uint64) string {
return fmt.Sprintf("%suser:tokens:%d", m.prefix, userID)
}
// getUserTokenKey 获取用户令牌键名
func (m *JWTManager) getUserTokenKey(userID uint64, tokenHash string) string {
return fmt.Sprintf("%suser:%d:token:%s", m.prefix, userID, tokenHash)
}
// ==================== 工具方法 ====================
// GetTokenInfo 获取令牌信息
func (m *JWTManager) GetTokenInfo(tokenString string) (map[string]interface{}, error) {
claims, err := m.parseTokenWithoutBlacklistCheck(tokenString)
if err != nil {
return nil, err
}
info := map[string]interface{}{
"user_id": claims.UserID,
"token_type": claims.Type,
"issued_at": claims.IssuedAt.Time.Format(time.RFC3339),
"expires_at": claims.ExpiresAt.Time.Format(time.RFC3339),
"issuer": claims.Issuer,
"subject": claims.Subject,
"remaining": time.Until(claims.ExpiresAt.Time).String(),
}
// 检查是否在黑名单中
if m.redis != nil {
blacklisted, _ := m.IsTokenBlacklisted(tokenString)
info["blacklisted"] = blacklisted
}
return info, nil
}
// ValidateAndExtract 验证令牌并提取用户ID
func (m *JWTManager) ValidateAndExtract(tokenString string) (uint64, error) {
claims, err := m.VerifyAccessToken(tokenString)
if err != nil {
return 0, err
}
return claims.UserID, nil
}
// Cleanup 清理过期的令牌数据
func (m *JWTManager) Cleanup() error {
if m.redis == nil {
return nil
}
// Redis 会自动清理过期的键
return nil
}
// GetBlacklistStats 获取黑名单统计信息
func (m *JWTManager) GetBlacklistStats() (map[string]interface{}, error) {
if m.redis == nil {
return nil, errors.New("redis 客户端未初始化")
}
stats := make(map[string]interface{})
// 获取访问令牌黑名单数量
accessPattern := fmt.Sprintf("%sblacklist:access:*", m.prefix)
accessKeys, err := m.redis.Keys(m.ctx, accessPattern).Result()
if err == nil {
stats["access_token_blacklist_count"] = len(accessKeys)
} else {
stats["access_token_blacklist_count"] = 0
}
// 获取刷新令牌黑名单数量
refreshPattern := fmt.Sprintf("%sblacklist:refresh:*", m.prefix)
refreshKeys, err := m.redis.Keys(m.ctx, refreshPattern).Result()
if err == nil {
stats["refresh_token_blacklist_count"] = len(refreshKeys)
} else {
stats["refresh_token_blacklist_count"] = 0
}
return stats, nil
}
// Ping 检查 Redis 连接
func (m *JWTManager) Ping() error {
if m.redis == nil {
return errors.New("redis 客户端未初始化")
}
_, err := m.redis.Ping(m.ctx).Result()
return err
}
// Close 关闭 Redis 连接
func (m *JWTManager) Close() error {
if m.redis == nil {
return nil
}
return m.redis.Close()
}