158 lines
4.4 KiB
Go
158 lines
4.4 KiB
Go
package middleware
|
||
|
||
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/golang-jwt/jwt/v4"
	"github.com/redis/go-redis/v9"
	"github.com/saas-mingyang/mingyang-admin-common/config"
	jwt2 "github.com/saas-mingyang/mingyang-admin-common/utils/jwt"
	"github.com/zeromicro/go-zero/rest/enum"

	"mingyang-admin-app-api/internal/types"
	"mingyang-admin-app-rpc/appclient"
	"mingyang-admin-app-rpc/types/app"
	"mingyang-admin-app-rpc/utils"
)
|
||
|
||
// 定义上下文键类型
|
||
|
||
type AuthorityMiddleware struct {
|
||
Rds redis.UniversalClient
|
||
AppRpc appclient.App
|
||
}
|
||
|
||
// NewAuthorityMiddleware 创建认证中间件
|
||
func NewAuthorityMiddleware(appRpc appclient.App, rds redis.UniversalClient) *AuthorityMiddleware {
|
||
return &AuthorityMiddleware{
|
||
AppRpc: appRpc,
|
||
Rds: rds,
|
||
}
|
||
}
|
||
|
||
// writeResult 写入统一格式的响应
|
||
func writeResult(w http.ResponseWriter, statusCode int, result *types.BaseDataInfo) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(statusCode)
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// writeError 写入错误响应(简化版)
|
||
func writeError(w http.ResponseWriter, statusCode int, message string) {
|
||
writeResult(w, statusCode, &types.BaseDataInfo{
|
||
Msg: message,
|
||
Code: 500,
|
||
})
|
||
}
|
||
|
||
// Handle 中间件处理函数
|
||
func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
startTime := time.Now()
|
||
// 1. 从 Authorization Header 中提取 Bearer Token
|
||
authHeader := r.Header.Get("Authorization")
|
||
if authHeader == "" {
|
||
writeError(w, 401, "Authorization header is required")
|
||
return
|
||
}
|
||
fromToken := jwt2.StripBearerPrefixFromToken(r.Header.Get("Authorization"))
|
||
if m.Rds != nil {
|
||
cacheKey := config.RedisTokenPrefix + fromToken
|
||
fmt.Printf("cacheKey: %s\n", cacheKey)
|
||
cachedUserID, err := m.Rds.Get(r.Context(), cacheKey).Result()
|
||
if err == nil && cachedUserID != "" {
|
||
// 从缓存中获取到用户ID
|
||
ctx := context.WithValue(r.Context(), enum.UserIdRpcCtxKey, cachedUserID)
|
||
r = r.WithContext(ctx)
|
||
next(w, r)
|
||
return
|
||
}
|
||
}
|
||
// 5. 调用 RPC 验证 Token
|
||
token, err := m.AppRpc.AuthToken(r.Context(), &app.AuthReq{Token: fromToken})
|
||
if err != nil {
|
||
fmt.Printf("Error validating token: %v\n", err)
|
||
// 根据错误类型返回不同的错误信息
|
||
var jwtErr *jwt.ValidationError
|
||
if errors.As(err, &jwtErr) {
|
||
switch {
|
||
case jwtErr.Errors&jwt.ValidationErrorExpired != 0:
|
||
writeError(w, 401, "Token has expired")
|
||
case jwtErr.Errors&jwt.ValidationErrorMalformed != 0:
|
||
writeError(w, 401, "Invalid token format")
|
||
case jwtErr.Errors&jwt.ValidationErrorSignatureInvalid != 0:
|
||
writeError(w, 401, "Invalid token signature")
|
||
case jwtErr.Errors&jwt.ValidationErrorNotValidYet != 0:
|
||
writeError(w, 401, "Token not valid yet")
|
||
default:
|
||
writeError(w, 401, "Invalid token")
|
||
}
|
||
} else {
|
||
// 网络错误或其他错误
|
||
writeError(w, 500, "Token validation service unavailable")
|
||
}
|
||
return
|
||
}
|
||
|
||
// 6. 获取用户ID
|
||
id := token.UserId
|
||
|
||
// 7. 缓存到 Redis(可选)
|
||
if m.Rds != nil {
|
||
cacheKey := fmt.Sprintf("token:%s", fromToken)
|
||
// 设置缓存,过期时间30分钟
|
||
m.Rds.Set(r.Context(), cacheKey, id, 30*time.Minute)
|
||
}
|
||
|
||
// 创建新的上下文,包含 Token 和用户信息
|
||
ctx := r.Context()
|
||
|
||
// 构建请求信息
|
||
content := &utils.UserContext{
|
||
Token: fromToken,
|
||
UserID: token.UserId,
|
||
TokenClaims: token,
|
||
ClientIP: getClientIP(r),
|
||
UserAgent: r.UserAgent(),
|
||
}
|
||
// 修正:必须接收 context.WithValue 的返回值
|
||
newContent := context.WithValue(ctx, utils.UserContent, content)
|
||
// 将新上下文设置到请求中
|
||
r = r.WithContext(newContent)
|
||
|
||
// 调用下一个处理器
|
||
next(w, r)
|
||
|
||
// 9. 记录请求日志(可选)
|
||
fmt.Printf("[%s] %s - UserID: %d - Duration: %v\n",
|
||
time.Now().Format("2006-01-02 15:04:05"),
|
||
r.URL.Path,
|
||
id,
|
||
time.Since(startTime))
|
||
}
|
||
}
|
||
|
||
// getClientIP returns a best-effort client IP address for r.
//
// Resolution order:
//  1. the first comma-separated entry of the X-Forwarded-For header, if it is
//     non-empty after trimming whitespace;
//  2. the host part of r.RemoteAddr, with the port and any IPv6 brackets
//     removed (e.g. "[::1]:8080" -> "::1");
//  3. r.RemoteAddr unchanged, if it cannot be split into host and port.
//
// X-Forwarded-For can be set by the client, and it is not verified here. Do
// not use the result for security decisions unless a trusted proxy is known
// to overwrite that header.
func getClientIP(r *http.Request) string {
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// The header may list several hops ("client, proxy1, proxy2"). By
		// convention the left-most entry is the originating client.
		first, _, _ := strings.Cut(forwarded, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port (or otherwise unsplittable), so return the address as given.
		return r.RemoteAddr
	}
	return host
}
|