mingyang-admin-iot-app/api/internal/middleware/authority_middleware.go

145 lines
3.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/golang-jwt/jwt/v4"
"github.com/redis/go-redis/v9"
"mingyang-admin-app-api/internal/types"
"mingyang-admin-app-rpc/appclient"
"mingyang-admin-app-rpc/types/app"
"net/http"
"strings"
"time"
)
// contextKey is an unexported string type used for request-context keys, so
// the keys below cannot collide with plain-string keys set by other packages.
type contextKey string

// Request-context keys populated by AuthorityMiddleware.Handle.
const (
	// UserIDKey stores the authenticated user's ID.
	UserIDKey = contextKey("user_id")
	// UserInfoKey stores the token-validation response returned by the app
	// RPC. It is not set when the user ID was served from the Redis cache.
	UserInfoKey = contextKey("user_info")
)
// AuthorityMiddleware authenticates incoming HTTP requests by bearer token.
// Tokens are validated through AppRpc; when Rds is non-nil, validation
// results are cached there to avoid repeating the RPC call.
type AuthorityMiddleware struct {
	// Rds is an optional token cache; a nil client disables caching.
	Rds redis.UniversalClient
	// AppRpc is the client used to validate tokens.
	AppRpc appclient.App
}

// NewAuthorityMiddleware builds an AuthorityMiddleware from the app RPC
// client and an optional Redis client (pass nil to disable caching).
func NewAuthorityMiddleware(appRpc appclient.App, rds redis.UniversalClient) *AuthorityMiddleware {
	mw := new(AuthorityMiddleware)
	mw.AppRpc = appRpc
	mw.Rds = rds
	return mw
}
// writeResult writes result as a JSON body with the given HTTP status code.
func writeResult(w http.ResponseWriter, statusCode int, result *types.BaseDataInfo) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	// The status line has already been sent, so an encoding or write failure
	// can no longer be reported to the client; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(result)
}

// writeError writes an error response whose body "code" matches the HTTP
// status code (previously the body always said 500, even for 401 responses).
func writeError(w http.ResponseWriter, statusCode int, message string) {
	writeResult(w, statusCode, &types.BaseDataInfo{
		Msg:  message,
		Code: statusCode,
	})
}
// bearerScheme is the Authorization scheme prefix accepted by Handle.
const bearerScheme = "Bearer "

// tokenCacheKey returns the Redis key under which a validated token's user ID is cached.
func tokenCacheKey(token string) string {
	return "token:" + token
}

// extractBearerToken parses an Authorization header value of the form
// "Bearer <token>". It returns the trimmed token, or an empty token and a
// non-empty client-facing message explaining why the header was rejected.
func extractBearerToken(authHeader string) (token string, errMsg string) {
	if authHeader == "" {
		return "", "Authorization header is required"
	}
	// RFC 7235: the auth-scheme is case-insensitive, so "bearer x" is accepted too.
	if len(authHeader) < len(bearerScheme) || !strings.EqualFold(authHeader[:len(bearerScheme)], bearerScheme) {
		return "", "Invalid authorization format, must be 'Bearer <token>'"
	}
	token = strings.TrimSpace(authHeader[len(bearerScheme):])
	if token == "" {
		return "", "Token cannot be empty"
	}
	return token, ""
}

// writeTokenValidationError maps an error from the token-validation RPC to an
// HTTP response: 401 for JWT validation failures, 500 for anything else.
//
// NOTE(review): if AppRpc is a remote (gRPC) client, errors arrive as RPC
// status errors and will not unwrap to *jwt.ValidationError, so every failure
// would take the 500 branch. Confirm how the RPC reports invalid tokens.
func writeTokenValidationError(w http.ResponseWriter, err error) {
	var jwtErr *jwt.ValidationError
	if !errors.As(err, &jwtErr) {
		// Network error or other non-validation failure.
		writeError(w, http.StatusInternalServerError, "Token validation service unavailable")
		return
	}
	switch {
	case jwtErr.Errors&jwt.ValidationErrorExpired != 0:
		writeError(w, http.StatusUnauthorized, "Token has expired")
	case jwtErr.Errors&jwt.ValidationErrorMalformed != 0:
		writeError(w, http.StatusUnauthorized, "Invalid token format")
	case jwtErr.Errors&jwt.ValidationErrorSignatureInvalid != 0:
		writeError(w, http.StatusUnauthorized, "Invalid token signature")
	case jwtErr.Errors&jwt.ValidationErrorNotValidYet != 0:
		writeError(w, http.StatusUnauthorized, "Token not valid yet")
	default:
		writeError(w, http.StatusUnauthorized, "Invalid token")
	}
}

// Handle wraps next with bearer-token authentication.
//
// The token is read from an "Authorization: Bearer <token>" header (scheme
// matched case-insensitively). If Rds is set and holds a non-empty entry for
// the token, that cached user ID is stored under UserIDKey and next is called
// without contacting the RPC. Otherwise the token is validated with
// AppRpc.AuthToken. On success:
//   - the user ID is stored under UserIDKey;
//   - the RPC response is stored under UserInfoKey;
//   - the user ID is cached in Rds for 30 minutes (best-effort);
//   - next is invoked.
//
// Malformed headers get a 401. RPC failures are mapped by
// writeTokenValidationError. In both cases next is not called.
func (m *AuthorityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		startTime := time.Now()

		tokenString, errMsg := extractBearerToken(r.Header.Get("Authorization"))
		if errMsg != "" {
			writeError(w, http.StatusUnauthorized, errMsg)
			return
		}
		cacheKey := tokenCacheKey(tokenString)

		// Fast path: reuse a cached validation result. Redis errors (including
		// a missing key) simply fall through to RPC validation.
		// NOTE(review): a cache hit stores the user ID as a string and does not
		// set UserInfoKey, while the RPC path stores token.UserId with its RPC
		// type plus the full response. Handlers that type-assert these values
		// will behave differently on cache hits; confirm and unify.
		if m.Rds != nil {
			cachedUserID, err := m.Rds.Get(r.Context(), cacheKey).Result()
			if err == nil && cachedUserID != "" {
				next(w, r.WithContext(context.WithValue(r.Context(), UserIDKey, cachedUserID)))
				return
			}
		}

		// Slow path: validate through the RPC. The token itself is never
		// logged, because it is a bearer credential.
		token, err := m.AppRpc.AuthToken(r.Context(), &app.AuthReq{Token: tokenString})
		if err != nil {
			fmt.Printf("Error validating token: %v\n", err)
			writeTokenValidationError(w, err)
			return
		}
		id := token.UserId

		if m.Rds != nil {
			// NOTE(review): the fixed 30-minute TTL is independent of the token's
			// own expiry, so an expired or revoked token may keep passing via
			// the cache for up to 30 minutes. Confirm this is acceptable.
			// Best-effort: a failed cache write only costs an extra RPC later.
			if err := m.Rds.Set(r.Context(), cacheKey, id, 30*time.Minute).Err(); err != nil {
				fmt.Printf("Error caching token validation result: %v\n", err)
			}
		}

		ctx := context.WithValue(r.Context(), UserIDKey, id)
		ctx = context.WithValue(ctx, UserInfoKey, token)

		// Duration covers only the authentication work done above.
		fmt.Printf("[%s] %s - UserID: %v - Duration: %v\n",
			time.Now().Format("2006-01-02 15:04:05"),
			r.URL.Path,
			id,
			time.Since(startTime))

		next(w, r.WithContext(ctx))
	}
}