mingyang-admin-iot-app/api/internal/middleware/authority_middleware.go

158 lines
4.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"github.com/redis/go-redis/v9"
	"github.com/saas-mingyang/mingyang-admin-common/config"
	jwt2 "github.com/saas-mingyang/mingyang-admin-common/utils/jwt"
	"github.com/zeromicro/go-zero/rest/enum"

	"mingyang-admin-app-api/internal/types"
	"mingyang-admin-app-rpc/appclient"
	"mingyang-admin-app-rpc/types/app"
)
// AuthorityMiddleware authenticates HTTP requests by their bearer token before
// passing them on to the next handler (see Handle).
type AuthorityMiddleware struct {
	// Rds is an optional Redis client used to cache token -> user ID lookups;
	// Handle skips the cache entirely when it is nil.
	Rds redis.UniversalClient
	// AppRpc is the app RPC client whose AuthToken call validates a token and
	// returns its user ID.
	AppRpc appclient.App
}
// NewAuthorityMiddleware returns an AuthorityMiddleware that validates tokens
// through appRpc. Passing a nil rds disables the Redis token cache.
func NewAuthorityMiddleware(appRpc appclient.App, rds redis.UniversalClient) *AuthorityMiddleware {
	mw := new(AuthorityMiddleware)
	mw.Rds = rds
	mw.AppRpc = appRpc
	return mw
}
// writeResult writes result as a JSON response body with the given HTTP
// status code and an application/json Content-Type.
//
// The status line is already sent by the time the body is encoded, so an
// encoding/write failure cannot be reported to the client; it is logged
// instead of being silently dropped.
func writeResult(w http.ResponseWriter, statusCode int, result *types.BaseDataInfo) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	if err := json.NewEncoder(w).Encode(result); err != nil {
		fmt.Printf("Error writing response: %v\n", err)
	}
}
// writeError writes an error response with the given HTTP status code and
// message.
//
// The body's Code mirrors statusCode. Clients that inspect only the JSON body
// can then tell an authentication failure (401) from a server-side fault (500).
// The code used to be hard-coded to 500 for every error.
// NOTE(review): assumes types.BaseDataInfo.Code is an int, as in goctl's
// generated base types — confirm.
func writeError(w http.ResponseWriter, statusCode int, message string) {
	writeResult(w, statusCode, &types.BaseDataInfo{
		Code: statusCode,
		Msg:  message,
	})
}
// Handle 中间件处理函数
func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
// 1. 从 Authorization Header 中提取 Bearer Token
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
writeError(w, 401, "Authorization header is required")
return
}
fromToken := jwt2.StripBearerPrefixFromToken(r.Header.Get("Authorization"))
if m.Rds != nil {
cacheKey := config.RedisTokenPrefix + fromToken
fmt.Printf("cacheKey: %s\n", cacheKey)
cachedUserID, err := m.Rds.Get(r.Context(), cacheKey).Result()
if err == nil && cachedUserID != "" {
// 从缓存中获取到用户ID
ctx := context.WithValue(r.Context(), enum.UserIdRpcCtxKey, cachedUserID)
r = r.WithContext(ctx)
next(w, r)
return
}
}
// 5. 调用 RPC 验证 Token
token, err := m.AppRpc.AuthToken(r.Context(), &app.AuthReq{Token: fromToken})
if err != nil {
fmt.Printf("Error validating token: %v\n", err)
// 根据错误类型返回不同的错误信息
var jwtErr *jwt.ValidationError
if errors.As(err, &jwtErr) {
switch {
case jwtErr.Errors&jwt.ValidationErrorExpired != 0:
writeError(w, 401, "Token has expired")
case jwtErr.Errors&jwt.ValidationErrorMalformed != 0:
writeError(w, 401, "Invalid token format")
case jwtErr.Errors&jwt.ValidationErrorSignatureInvalid != 0:
writeError(w, 401, "Invalid token signature")
case jwtErr.Errors&jwt.ValidationErrorNotValidYet != 0:
writeError(w, 401, "Token not valid yet")
default:
writeError(w, 401, "Invalid token")
}
} else {
// 网络错误或其他错误
writeError(w, 500, "Token validation service unavailable")
}
return
}
// 6. 获取用户ID
id := token.UserId
// 7. 缓存到 Redis可选
if m.Rds != nil {
cacheKey := fmt.Sprintf("token:%s", fromToken)
// 设置缓存过期时间30分钟
m.Rds.Set(r.Context(), cacheKey, id, 30*time.Minute)
}
// 创建新的上下文,包含 Token 和用户信息
ctx := r.Context()
ctx = context.WithValue(ctx, "token", fromToken)
ctx = context.WithValue(ctx, "userId", token.UserId)
ctx = context.WithValue(ctx, "tokenClaims", token)
// 获取客户端 IP
clientIP := getClientIP(r)
ctx = context.WithValue(ctx, "client_ip", clientIP)
// 获取 User-Agent
userAgent := r.UserAgent()
ctx = context.WithValue(ctx, "user_agent", userAgent)
// 将新上下文设置到请求中
r = r.WithContext(ctx)
// 调用下一个处理器
next(w, r)
// 9. 记录请求日志(可选)
fmt.Printf("[%s] %s - UserID: %d - Duration: %v\n",
time.Now().Format("2006-01-02 15:04:05"),
r.URL.Path,
id,
time.Since(startTime))
}
}
// getClientIP returns the client IP address for request context/logging.
//
// It prefers the first entry of X-Forwarded-For when that entry is non-empty.
// Otherwise it returns the host part of r.RemoteAddr, or RemoteAddr unchanged
// when it carries no port.
//
// NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
// overwrites it. Do not use this value for access control or rate limiting.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// The header may list several hops ("client, proxy1, proxy2"); the
		// first is the originating client.
		first, _, _ := strings.Cut(forwarded, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	// SplitHostPort handles bracketed IPv6 ("[::1]:8080" -> "::1"). A plain
	// split on the last ':' would keep the brackets, and would mangle a bare
	// IPv6 address without a port.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}