145 lines
3.9 KiB
Go
145 lines
3.9 KiB
Go
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/golang-jwt/jwt/v4"
|
||
"github.com/redis/go-redis/v9"
|
||
"mingyang-admin-app-api/internal/types"
|
||
"mingyang-admin-app-rpc/appclient"
|
||
"mingyang-admin-app-rpc/types/app"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// contextKey is an unexported string type used for request-context keys.
// Because the type is private to this package, keys defined here cannot
// collide with plain-string keys or keys defined by other packages.
type contextKey string

// Request-context keys populated by AuthorityMiddleware.Handle.
const (
	// UserIDKey holds the authenticated user's ID.
	UserIDKey = contextKey("user_id")
	// UserInfoKey holds the token-validation response returned by the app RPC.
	UserInfoKey = contextKey("user_info")
)
|
||
|
||
type AuthorityMiddleware struct {
|
||
Rds redis.UniversalClient
|
||
AppRpc appclient.App
|
||
}
|
||
|
||
// NewAuthorityMiddleware 创建认证中间件
|
||
func NewAuthorityMiddleware(appRpc appclient.App, rds redis.UniversalClient) *AuthorityMiddleware {
|
||
return &AuthorityMiddleware{
|
||
AppRpc: appRpc,
|
||
Rds: rds,
|
||
}
|
||
}
|
||
|
||
// writeResult 写入统一格式的响应
|
||
func writeResult(w http.ResponseWriter, statusCode int, result *types.BaseDataInfo) {
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(statusCode)
|
||
json.NewEncoder(w).Encode(result)
|
||
}
|
||
|
||
// writeError 写入错误响应(简化版)
|
||
func writeError(w http.ResponseWriter, statusCode int, message string) {
|
||
writeResult(w, statusCode, &types.BaseDataInfo{
|
||
Msg: message,
|
||
Code: 500,
|
||
})
|
||
}
|
||
|
||
// Handle 中间件处理函数
|
||
func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
startTime := time.Now()
|
||
// 1. 从 Authorization Header 中提取 Bearer Token
|
||
authHeader := r.Header.Get("Authorization")
|
||
if authHeader == "" {
|
||
writeError(w, 401, "Authorization header is required")
|
||
return
|
||
}
|
||
|
||
// 2. 验证 Bearer Token 格式
|
||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||
writeError(w, 401, "Invalid authorization format, must be 'Bearer <token>'")
|
||
return
|
||
}
|
||
|
||
// 3. 提取 Token
|
||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||
if tokenString == "" {
|
||
writeError(w, 401, "Token cannot be empty")
|
||
return
|
||
}
|
||
|
||
// 4. 检查 Redis 缓存(可选)
|
||
if m.Rds != nil {
|
||
cacheKey := fmt.Sprintf("token:%s", tokenString)
|
||
cachedUserID, err := m.Rds.Get(r.Context(), cacheKey).Result()
|
||
if err == nil && cachedUserID != "" {
|
||
// 从缓存中获取到用户ID
|
||
ctx := context.WithValue(r.Context(), UserIDKey, cachedUserID)
|
||
r = r.WithContext(ctx)
|
||
next(w, r)
|
||
return
|
||
}
|
||
}
|
||
|
||
// 5. 调用 RPC 验证 Token
|
||
fmt.Printf("Validating token: %s\n", tokenString)
|
||
token, err := m.AppRpc.AuthToken(r.Context(), &app.AuthReq{Token: tokenString})
|
||
if err != nil {
|
||
fmt.Printf("Error validating token: %v\n", err)
|
||
// 根据错误类型返回不同的错误信息
|
||
var jwtErr *jwt.ValidationError
|
||
if errors.As(err, &jwtErr) {
|
||
switch {
|
||
case jwtErr.Errors&jwt.ValidationErrorExpired != 0:
|
||
writeError(w, 401, "Token has expired")
|
||
case jwtErr.Errors&jwt.ValidationErrorMalformed != 0:
|
||
writeError(w, 401, "Invalid token format")
|
||
case jwtErr.Errors&jwt.ValidationErrorSignatureInvalid != 0:
|
||
writeError(w, 401, "Invalid token signature")
|
||
case jwtErr.Errors&jwt.ValidationErrorNotValidYet != 0:
|
||
writeError(w, 401, "Token not valid yet")
|
||
default:
|
||
writeError(w, 401, "Invalid token")
|
||
}
|
||
} else {
|
||
// 网络错误或其他错误
|
||
writeError(w, 500, "Token validation service unavailable")
|
||
}
|
||
return
|
||
}
|
||
|
||
// 6. 获取用户ID
|
||
id := token.UserId
|
||
|
||
// 7. 缓存到 Redis(可选)
|
||
if m.Rds != nil {
|
||
cacheKey := fmt.Sprintf("token:%s", tokenString)
|
||
// 设置缓存,过期时间30分钟
|
||
m.Rds.Set(r.Context(), cacheKey, id, 30*time.Minute)
|
||
}
|
||
|
||
// 8. 设置到上下文
|
||
ctx := r.Context()
|
||
ctx = context.WithValue(ctx, UserIDKey, id)
|
||
ctx = context.WithValue(ctx, UserInfoKey, token)
|
||
|
||
// 9. 记录请求日志(可选)
|
||
fmt.Printf("[%s] %s - UserID: %d - Duration: %v\n",
|
||
time.Now().Format("2006-01-02 15:04:05"),
|
||
r.URL.Path,
|
||
id,
|
||
time.Since(startTime))
|
||
|
||
// 10. 继续处理请求
|
||
r = r.WithContext(ctx)
|
||
next(w, r)
|
||
}
|
||
}
|