143 lines
3.0 KiB
Go
143 lines
3.0 KiB
Go
// Package middleware 限流中间件
|
||
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
"server/common"
|
||
"server/config"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/redis/go-redis/v9"
|
||
)
|
||
|
||
// RateLimitMiddleware 限流中间件
|
||
// 基于 Redis 实现,支持按用户ID或IP限流
|
||
// 不同接口可配置不同的限流规则
|
||
func RateLimitMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
cfg := config.AppConfig.RateLimit
|
||
|
||
// 未启用则跳过
|
||
if !cfg.Enable {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 白名单路径跳过
|
||
path := c.Request.URL.Path
|
||
if isRateLimitWhitelist(path) {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
// 获取限流规则
|
||
rule := getRule(path, cfg)
|
||
|
||
// 获取限流key (优先用户ID,否则用IP)
|
||
key := getRateLimitKey(c, path)
|
||
|
||
// 检查是否超过限制
|
||
if !checkRateLimit(key, rule) {
|
||
common.Warn("请求过于频繁: Key=%s Path=%s", key, path)
|
||
c.JSON(429, map[string]interface{}{
|
||
"code": 429,
|
||
"message": "操作过快,请稍后再试",
|
||
"data": nil,
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// getRule 获取路径对应的限流规则
|
||
func getRule(path string, cfg config.RateLimitConfig) config.RateLimitRule {
|
||
// 精确匹配
|
||
if rule, ok := cfg.Rules[path]; ok {
|
||
return rule
|
||
}
|
||
|
||
// 前缀匹配
|
||
for rulePath, rule := range cfg.Rules {
|
||
if strings.HasPrefix(path, rulePath) {
|
||
return rule
|
||
}
|
||
}
|
||
|
||
// 返回默认规则
|
||
return cfg.Default
|
||
}
|
||
|
||
// getRateLimitKey 获取限流key
|
||
func getRateLimitKey(c *gin.Context, path string) string {
|
||
// 优先使用用户ID
|
||
if user := common.GetLoginUser(c); user != nil {
|
||
return fmt.Sprintf("ratelimit:%s:%s", user.ID, path)
|
||
}
|
||
// 否则使用IP
|
||
return fmt.Sprintf("ratelimit:%s:%s", c.ClientIP(), path)
|
||
}
|
||
|
||
// checkRateLimit 检查是否超过限流
|
||
// 使用 Redis 滑动窗口算法
|
||
func checkRateLimit(key string, rule config.RateLimitRule) bool {
|
||
ctx := context.Background()
|
||
rdb := config.RDB
|
||
|
||
now := time.Now().UnixMilli()
|
||
windowStart := now - int64(rule.Interval*1000)
|
||
|
||
// 使用 Redis 事务
|
||
pipe := rdb.Pipeline()
|
||
|
||
// 移除窗口外的记录
|
||
pipe.ZRemRangeByScore(ctx, key, "0", fmt.Sprintf("%d", windowStart))
|
||
|
||
// 获取当前窗口内的请求数
|
||
countCmd := pipe.ZCard(ctx, key)
|
||
|
||
// 添加当前请求
|
||
pipe.ZAdd(ctx, key, redis.Z{
|
||
Score: float64(now),
|
||
Member: fmt.Sprintf("%d", now),
|
||
})
|
||
|
||
// 设置过期时间
|
||
pipe.Expire(ctx, key, time.Duration(rule.Interval)*time.Second)
|
||
|
||
_, err := pipe.Exec(ctx)
|
||
if err != nil {
|
||
common.LogError("限流检查失败: %v", err)
|
||
return true // 出错时放行
|
||
}
|
||
|
||
count := countCmd.Val()
|
||
return count < int64(rule.MaxRequests)
|
||
}
|
||
|
||
// 限流白名单
|
||
var rateLimitWhitelist = []string{
|
||
"/swagger/",
|
||
"/api/auth/logout",
|
||
}
|
||
|
||
func isRateLimitWhitelist(path string) bool {
|
||
for _, white := range rateLimitWhitelist {
|
||
if strings.HasPrefix(path, white) {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// AddRateLimitWhitelist 添加限流白名单
|
||
func AddRateLimitWhitelist(paths ...string) {
|
||
rateLimitWhitelist = append(rateLimitWhitelist, paths...)
|
||
}
|