package middleware

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"github.com/redis/go-redis/v9"
	"mingyang-admin-app-api/internal/types"
	"mingyang-admin-app-rpc/appclient"
	"mingyang-admin-app-rpc/types/app"
)

// contextKey is a package-private type for request-context keys, so the keys
// below cannot collide with context values set by other packages.
type contextKey string

const (
	// UserIDKey is the request-context key holding the authenticated user's ID.
	//
	// NOTE(review): on a Redis cache hit the value is the cached string, while
	// on the RPC path it is token.UserId exactly as returned by AuthToken.
	// These may be different Go types — confirm what downstream handlers assert.
	UserIDKey contextKey = "user_id"
	// UserInfoKey is the request-context key holding the full AuthToken RPC
	// response. It is only set when the token was validated via RPC; it is
	// absent when the request was authorised from the Redis cache.
	UserInfoKey contextKey = "user_info"
)

// AuthorityMiddleware authenticates requests by validating their bearer token
// through the App RPC service, optionally caching successful validations in Redis.
type AuthorityMiddleware struct {
	Rds    redis.UniversalClient // optional; nil disables token caching
	AppRpc appclient.App
}

// NewAuthorityMiddleware creates the authentication middleware.
func NewAuthorityMiddleware(appRpc appclient.App, rds redis.UniversalClient) *AuthorityMiddleware {
	return &AuthorityMiddleware{
		AppRpc: appRpc,
		Rds:    rds,
	}
}

// writeResult writes result as a JSON response with the given HTTP status code.
func writeResult(w http.ResponseWriter, statusCode int, result *types.BaseDataInfo) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	// The status line and headers are already sent, so an encoding failure
	// cannot be reported to the client any more; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(result)
}

// writeError writes an error response whose body code mirrors the HTTP status,
// so clients reading either one see the same outcome (e.g. 401, not 500).
func writeError(w http.ResponseWriter, statusCode int, message string) {
	writeResult(w, statusCode, &types.BaseDataInfo{
		Msg:  message,
		Code: statusCode,
	})
}

// tokenCacheKey returns the Redis key under which a validated token's user ID
// is cached. Both the lookup and the store go through this helper so they
// cannot drift apart.
//
// NOTE(review): the raw token is part of the key, so anyone with Redis read
// access can harvest live tokens. Hashing it would be safer, but other code
// (e.g. a logout path) may rely on this exact key format — check before changing.
func tokenCacheKey(tokenString string) string {
	return fmt.Sprintf("token:%s", tokenString)
}

// parseBearerToken extracts the token from an Authorization header value of
// the form "Bearer <token>". The scheme is matched case-insensitively
// (RFC 7235), and surrounding whitespace is trimmed from the token. On failure,
// token is empty and problem holds the client-facing error message.
func parseBearerToken(header string) (token, problem string) {
	const scheme = "Bearer "
	if header == "" {
		return "", "Authorization header is required"
	}
	if len(header) < len(scheme) || !strings.EqualFold(header[:len(scheme)], scheme) {
		return "", "Invalid authorization format, must be 'Bearer '"
	}
	token = strings.TrimSpace(header[len(scheme):])
	if token == "" {
		return "", "Token cannot be empty"
	}
	return token, ""
}

// writeTokenValidationError maps an AuthToken RPC error to an HTTP response.
// JWT validation failures become 401 with a specific message; anything else
// is reported as 500.
func writeTokenValidationError(w http.ResponseWriter, err error) {
	var jwtErr *jwt.ValidationError
	if !errors.As(err, &jwtErr) {
		// Network error or other non-JWT failure.
		// NOTE(review): if AppRpc is a remote gRPC client, a server-side
		// *jwt.ValidationError is not preserved across the wire, so invalid
		// tokens may also land here as 500. Consider mapping the gRPC status code.
		writeError(w, http.StatusInternalServerError, "Token validation service unavailable")
		return
	}
	switch {
	case jwtErr.Errors&jwt.ValidationErrorExpired != 0:
		writeError(w, http.StatusUnauthorized, "Token has expired")
	case jwtErr.Errors&jwt.ValidationErrorMalformed != 0:
		writeError(w, http.StatusUnauthorized, "Invalid token format")
	case jwtErr.Errors&jwt.ValidationErrorSignatureInvalid != 0:
		writeError(w, http.StatusUnauthorized, "Invalid token signature")
	case jwtErr.Errors&jwt.ValidationErrorNotValidYet != 0:
		writeError(w, http.StatusUnauthorized, "Token not valid yet")
	default:
		writeError(w, http.StatusUnauthorized, "Invalid token")
	}
}

// tokenCacheTTL returns how long a validated token may be cached: maxTTL, but
// never past the token's own "exp" claim. Without this cap, a cache hit would
// keep accepting a token after it expired.
//
// The claims are read WITHOUT signature verification. That is acceptable only
// because this is called after the AuthToken RPC has accepted the token. If the
// token is not a parseable JWT or carries no usable "exp", maxTTL is returned.
// A result <= 0 means the token must not be cached.
func tokenCacheTTL(tokenString string, maxTTL time.Duration) time.Duration {
	claims := jwt.MapClaims{}
	if _, _, err := new(jwt.Parser).ParseUnverified(tokenString, claims); err != nil {
		return maxTTL
	}
	var expUnix int64
	switch exp := claims["exp"].(type) {
	case float64: // default JSON decoding of numbers
		expUnix = int64(exp)
	case json.Number: // when the parser is configured with UseJSONNumber
		n, err := exp.Int64()
		if err != nil {
			return maxTTL
		}
		expUnix = n
	default:
		return maxTTL
	}
	if untilExp := time.Until(time.Unix(expUnix, 0)); untilExp < maxTTL {
		return untilExp
	}
	return maxTTL
}

// Handle wraps next with bearer-token authentication:
//
//  1. The Authorization header must carry a "Bearer <token>" value; otherwise
//     the request is rejected with 401.
//  2. If Redis is configured and holds a user ID for the token, next runs with
//     only UserIDKey (the cached string) set in the request context.
//  3. Otherwise the token is validated via the AuthToken RPC. On success the
//     user ID is cached (at most 30 minutes, never beyond the token's exp), and
//     next runs with UserIDKey and UserInfoKey set.
func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	const maxCacheTTL = 30 * time.Minute

	return func(w http.ResponseWriter, r *http.Request) {
		startTime := time.Now()

		tokenString, problem := parseBearerToken(r.Header.Get("Authorization"))
		if problem != "" {
			writeError(w, http.StatusUnauthorized, problem)
			return
		}

		// Fast path: a recently validated token is cached in Redis. The lookup
		// is best effort; any Redis error falls through to the RPC check.
		if m.Rds != nil {
			cachedUserID, err := m.Rds.Get(r.Context(), tokenCacheKey(tokenString)).Result()
			if err == nil && cachedUserID != "" {
				ctx := context.WithValue(r.Context(), UserIDKey, cachedUserID)
				next(w, r.WithContext(ctx))
				return
			}
		}

		// The token is a credential, so it is deliberately not logged here.
		token, err := m.AppRpc.AuthToken(r.Context(), &app.AuthReq{Token: tokenString})
		if err != nil {
			fmt.Printf("Error validating token: %v\n", err)
			writeTokenValidationError(w, err)
			return
		}
		id := token.UserId

		// Cache the validation result (best effort: failures are logged, not fatal).
		if m.Rds != nil {
			if ttl := tokenCacheTTL(tokenString, maxCacheTTL); ttl > 0 {
				if err := m.Rds.Set(r.Context(), tokenCacheKey(tokenString), id, ttl).Err(); err != nil {
					fmt.Printf("Error caching token validation: %v\n", err)
				}
			}
		}

		ctx := context.WithValue(r.Context(), UserIDKey, id)
		ctx = context.WithValue(ctx, UserInfoKey, token)

		fmt.Printf("[%s] %s - UserID: %v - Duration: %v\n",
			time.Now().Format("2006-01-02 15:04:05"),
			r.URL.Path,
			id,
			time.Since(startTime))

		next(w, r.WithContext(ctx))
	}
}