package middleware

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"github.com/redis/go-redis/v9"
	"github.com/saas-mingyang/mingyang-admin-common/config"
	jwt2 "github.com/saas-mingyang/mingyang-admin-common/utils/jwt"
	"github.com/zeromicro/go-zero/rest/enum"

	"mingyang-admin-app-api/internal/types"
	"mingyang-admin-app-rpc/appclient"
	"mingyang-admin-app-rpc/types/app"
	"mingyang-admin-app-rpc/utils"
)

// tokenCacheTTL is how long a validated token's user ID stays in Redis.
//
// NOTE(review): the TTL is independent of the token's own expiry, so a cached
// entry may outlive the token it was created for — confirm this is acceptable.
const tokenCacheTTL = 30 * time.Minute

// AuthorityMiddleware authenticates requests carrying a bearer token. It first
// consults Redis (when Rds is non-nil) and falls back to the App RPC service.
type AuthorityMiddleware struct {
	Rds    redis.UniversalClient
	AppRpc appclient.App
}

// NewAuthorityMiddleware builds an AuthorityMiddleware from the RPC client and
// the Redis client. rds may be nil, in which case caching is skipped.
func NewAuthorityMiddleware(appRpc appclient.App, rds redis.UniversalClient) *AuthorityMiddleware {
	return &AuthorityMiddleware{
		AppRpc: appRpc,
		Rds:    rds,
	}
}

// tokenCacheKey returns the Redis key under which the user ID for token is
// stored. Reads and writes must share this key; previously the write used a
// hard-coded "token:" prefix while the read used config.RedisTokenPrefix.
func tokenCacheKey(token string) string {
	return config.RedisTokenPrefix + token
}

// writeResult writes result as a JSON body with the given HTTP status code.
func writeResult(w http.ResponseWriter, statusCode int, result *types.BaseDataInfo) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	// The status line is already sent, so an encode failure can only be logged.
	if err := json.NewEncoder(w).Encode(result); err != nil {
		fmt.Printf("Error writing response: %v\n", err)
	}
}

// writeError writes an error response with the given HTTP status and message.
//
// NOTE(review): the body's Code is always 500, even when statusCode is 401 —
// confirm whether clients rely on this before aligning it with statusCode.
func writeError(w http.ResponseWriter, statusCode int, message string) {
	writeResult(w, statusCode, &types.BaseDataInfo{
		Msg:  message,
		Code: 500,
	})
}

// writeTokenError maps a token validation error to an HTTP response: JWT
// validation errors become 401 with a specific message, anything else is
// reported as 500.
//
// NOTE(review): errors that crossed the RPC boundary may no longer unwrap to
// *jwt.ValidationError, in which case every failure lands in the 500 branch —
// verify against the RPC server's error encoding.
func writeTokenError(w http.ResponseWriter, err error) {
	var jwtErr *jwt.ValidationError
	if !errors.As(err, &jwtErr) {
		// Network failure or any other non-JWT error.
		writeError(w, http.StatusInternalServerError, "Token validation service unavailable")
		return
	}

	msg := "Invalid token"
	switch {
	case jwtErr.Errors&jwt.ValidationErrorExpired != 0:
		msg = "Token has expired"
	case jwtErr.Errors&jwt.ValidationErrorMalformed != 0:
		msg = "Invalid token format"
	case jwtErr.Errors&jwt.ValidationErrorSignatureInvalid != 0:
		msg = "Invalid token signature"
	case jwtErr.Errors&jwt.ValidationErrorNotValidYet != 0:
		msg = "Token not valid yet"
	}
	writeError(w, http.StatusUnauthorized, msg)
}

// Handle wraps next with bearer-token authentication.
//
// The token is taken from the Authorization header. If Redis holds a user ID
// for it, the request proceeds with that ID stored under enum.UserIdRpcCtxKey.
// Otherwise the token is validated through AppRpc.AuthToken; on success the
// user ID is cached and the request proceeds with a *utils.UserContext stored
// under utils.UserContent (plus the user ID under enum.UserIdRpcCtxKey).
// Requests without a usable token, or whose token fails validation, are
// answered here and never reach next.
func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		startTime := time.Now()

		// Extract the bearer token from the Authorization header.
		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			writeError(w, http.StatusUnauthorized, "Authorization header is required")
			return
		}
		fromToken := jwt2.StripBearerPrefixFromToken(authHeader)
		if fromToken == "" {
			// e.g. a bare "Bearer " header: nothing to look up or validate.
			writeError(w, http.StatusUnauthorized, "Bearer token is required")
			return
		}

		// Fast path: the token was validated recently and its user ID is cached.
		// The cache key embeds the raw token, so it is deliberately not logged.
		if m.Rds != nil {
			cachedUserID, err := m.Rds.Get(r.Context(), tokenCacheKey(fromToken)).Result()
			switch {
			case err == nil && cachedUserID != "":
				// NOTE(review): this path does not populate utils.UserContent;
				// handlers that need the full UserContext only get it on a
				// cache miss.
				ctx := context.WithValue(r.Context(), enum.UserIdRpcCtxKey, cachedUserID)
				next(w, r.WithContext(ctx))
				return
			case err != nil && !errors.Is(err, redis.Nil):
				// Cache is best-effort: log and fall back to RPC validation.
				fmt.Printf("Error reading token cache: %v\n", err)
			}
		}

		// Slow path: validate the token through the RPC service.
		token, err := m.AppRpc.AuthToken(r.Context(), &app.AuthReq{Token: fromToken})
		if err != nil {
			fmt.Printf("Error validating token: %v\n", err)
			writeTokenError(w, err)
			return
		}
		id := token.UserId

		// Cache the user ID so later requests can take the fast path. A cache
		// failure must not fail the request.
		if m.Rds != nil {
			if err := m.Rds.Set(r.Context(), tokenCacheKey(fromToken), id, tokenCacheTTL).Err(); err != nil {
				fmt.Printf("Error caching token: %v\n", err)
			}
		}

		// Attach the token and user details to the request context.
		content := &utils.UserContext{
			Token:       fromToken,
			UserID:      token.UserId,
			TokenClaims: token,
			ClientIP:    getClientIP(r),
			UserAgent:   r.UserAgent(),
		}
		ctx := context.WithValue(r.Context(), utils.UserContent, content)
		// Also expose the user ID under the key used by the cached path, so
		// handlers see it regardless of whether the cache was hit.
		ctx = context.WithValue(ctx, enum.UserIdRpcCtxKey, fmt.Sprint(id))
		r = r.WithContext(ctx)

		next(w, r)

		// Request log for RPC-validated requests.
		fmt.Printf("[%s] %s - UserID: %d - Duration: %v\n",
			time.Now().Format("2006-01-02 15:04:05"),
			r.URL.Path,
			id,
			time.Since(startTime))
	}
}

// getClientIP returns the client address for r: the first non-empty entry of
// X-Forwarded-For if present, otherwise the host part of r.RemoteAddr.
func getClientIP(r *http.Request) string {
	// X-Forwarded-For may list several addresses; the first one is used.
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first, _, _ := strings.Cut(forwarded, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	// SplitHostPort handles bracketed IPv6 ("[::1]:8080" -> "::1"); if the
	// address has no port, return it unchanged.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}