// Package middleware 限流中间件 package middleware import ( "context" "fmt" "strings" "time" "server/common" "server/config" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" ) // RateLimitMiddleware 限流中间件 // 基于 Redis 实现,支持按用户ID或IP限流 // 不同接口可配置不同的限流规则 func RateLimitMiddleware() gin.HandlerFunc { return func(c *gin.Context) { cfg := config.AppConfig.RateLimit // 未启用则跳过 if !cfg.Enable { c.Next() return } // 白名单路径跳过 path := c.Request.URL.Path if isRateLimitWhitelist(path) { c.Next() return } // 获取限流规则 rule := getRule(path, cfg) // 获取限流key (优先用户ID,否则用IP) key := getRateLimitKey(c, path) // 检查是否超过限制 if !checkRateLimit(key, rule) { common.Warn("请求过于频繁: Key=%s Path=%s", key, path) c.JSON(429, map[string]interface{}{ "code": 429, "message": "操作过快,请稍后再试", "data": nil, }) c.Abort() return } c.Next() } } // getRule 获取路径对应的限流规则 func getRule(path string, cfg config.RateLimitConfig) config.RateLimitRule { // 精确匹配 if rule, ok := cfg.Rules[path]; ok { return rule } // 前缀匹配 for rulePath, rule := range cfg.Rules { if strings.HasPrefix(path, rulePath) { return rule } } // 返回默认规则 return cfg.Default } // getRateLimitKey 获取限流key func getRateLimitKey(c *gin.Context, path string) string { // 优先使用用户ID if user := common.GetLoginUser(c); user != nil { return fmt.Sprintf("ratelimit:%s:%s", user.ID, path) } // 否则使用IP return fmt.Sprintf("ratelimit:%s:%s", c.ClientIP(), path) } // checkRateLimit 检查是否超过限流 // 使用 Redis 滑动窗口算法 func checkRateLimit(key string, rule config.RateLimitRule) bool { ctx := context.Background() rdb := config.RDB now := time.Now().UnixMilli() windowStart := now - int64(rule.Interval*1000) // 使用 Redis 事务 pipe := rdb.Pipeline() // 移除窗口外的记录 pipe.ZRemRangeByScore(ctx, key, "0", fmt.Sprintf("%d", windowStart)) // 获取当前窗口内的请求数 countCmd := pipe.ZCard(ctx, key) // 添加当前请求 pipe.ZAdd(ctx, key, redis.Z{ Score: float64(now), Member: fmt.Sprintf("%d", now), }) // 设置过期时间 pipe.Expire(ctx, key, time.Duration(rule.Interval)*time.Second) _, err := pipe.Exec(ctx) if err != nil { common.LogError("限流检查失败: %v", err) return true // 出错时放行 } count := countCmd.Val() return count < int64(rule.MaxRequests) } // 限流白名单 var rateLimitWhitelist = []string{ "/swagger/", "/api/auth/logout", } func isRateLimitWhitelist(path string) bool { for _, white := range rateLimitWhitelist { if strings.HasPrefix(path, white) { return true } } return false } // AddRateLimitWhitelist 添加限流白名单 func AddRateLimitWhitelist(paths ...string) { rateLimitWhitelist = append(rateLimitWhitelist, paths...) }