golang-yitisheng-server/server/middleware/ratelimit.go

143 lines
3.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package middleware 限流中间件
package middleware
import (
	"context"
	"fmt"
	"math/rand"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/redis/go-redis/v9"

	"server/common"
	"server/config"
)
// RateLimitMiddleware returns a Gin middleware that throttles requests using
// Redis-backed counters.
//
// For each request it:
//   - passes the request through untouched when
//     config.AppConfig.RateLimit.Enable is false;
//   - passes through paths accepted by isRateLimitWhitelist;
//   - otherwise resolves the rule for the path (getRule) and the counter key
//     (getRateLimitKey: logged-in user ID when available, client IP
//     otherwise), and asks checkRateLimit whether the request is allowed;
//   - when it is not, logs a warning, responds with HTTP 429 and the JSON body
//     {"code":429,"message":"操作过快,请稍后再试","data":null}, and aborts the
//     handler chain.
//
// The rate-limit config is read inside the handler, so it is re-read on every
// request rather than captured once when the middleware is built.
func RateLimitMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		settings := config.AppConfig.RateLimit
		if !settings.Enable {
			// Rate limiting is switched off globally.
			c.Next()
			return
		}

		reqPath := c.Request.URL.Path
		if isRateLimitWhitelist(reqPath) {
			c.Next()
			return
		}

		rule := getRule(reqPath, settings)
		limitKey := getRateLimitKey(c, reqPath)
		if checkRateLimit(limitKey, rule) {
			// Within the limit: continue to the next handler.
			c.Next()
			return
		}

		common.Warn("请求过于频繁: Key=%s Path=%s", limitKey, reqPath)
		c.JSON(429, gin.H{
			"code":    429,
			"message": "操作过快,请稍后再试",
			"data":    nil,
		})
		c.Abort()
	}
}
// getRule returns the rate-limit rule that applies to path.
//
// Lookup order:
//  1. an exact key match in cfg.Rules;
//  2. otherwise the LONGEST key in cfg.Rules that is a prefix of path;
//  3. otherwise cfg.Default.
//
// Choosing the longest prefix makes the result deterministic. Go map iteration
// order is randomized, so returning the first prefix hit from a range loop
// could pick "/api/" on one request and "/api/auth/" on the next when both are
// configured.
//
// Matching is plain string-prefix matching (strings.HasPrefix). For example, a
// rule keyed "/api/user" also applies to "/api/users".
func getRule(path string, cfg config.RateLimitConfig) config.RateLimitRule {
	// Fast path: exact match.
	if rule, ok := cfg.Rules[path]; ok {
		return rule
	}

	// Longest-prefix match. Two distinct keys of equal length cannot both be
	// prefixes of the same path, so the winner is unique.
	best := cfg.Default
	bestLen := -1
	for rulePath, rule := range cfg.Rules {
		if len(rulePath) > bestLen && strings.HasPrefix(path, rulePath) {
			best, bestLen = rule, len(rulePath)
		}
	}
	return best
}
// getRateLimitKey builds the Redis key for the counter that tracks this
// request, in the form "ratelimit:<subject>:<path>".
//
// The subject is the logged-in user's ID when common.GetLoginUser returns a
// user, and c.ClientIP() otherwise. Each subject therefore gets a separate
// counter per request path.
//
// NOTE(review): user IDs and IP addresses share one key namespace; this relies
// on the two never producing the same string. The %s verb also assumes
// user.ID is a string. If it is numeric, %v would be needed — confirm against
// the user type.
func getRateLimitKey(c *gin.Context, path string) string {
	const keyFormat = "ratelimit:%s:%s"

	user := common.GetLoginUser(c)
	if user == nil {
		// Anonymous request: fall back to the client's IP address.
		return fmt.Sprintf(keyFormat, c.ClientIP(), path)
	}
	return fmt.Sprintf(keyFormat, user.ID, path)
}
// checkRateLimit reports whether the request identified by key is allowed
// under rule. It uses a Redis sorted set stored at key as a sliding-window log
// of request timestamps.
//
// In one atomic MULTI/EXEC transaction it:
//  1. removes entries whose score (Unix ms) is <= now - rule.Interval*1000;
//  2. counts the entries still in the window;
//  3. records the current request;
//  4. refreshes the key's TTL to rule.Interval seconds.
//
// The request is allowed when the count from step 2 (which excludes the
// current request) is below rule.MaxRequests. The current request is recorded
// even when it is rejected, so a client that keeps retrying while over the
// limit keeps its window full.
//
// The function fails open: if the Redis transaction errors, the error is
// logged and the request is allowed.
func checkRateLimit(key string, rule config.RateLimitRule) bool {
	ctx := context.Background()
	rdb := config.RDB
	now := time.Now().UnixMilli()
	windowStart := now - int64(rule.Interval*1000)

	// TxPipeline wraps the commands in MULTI/EXEC, so trim/count/add runs
	// atomically. A plain Pipeline only batches the round trip. With it,
	// concurrent requests for the same key can interleave between ZCARD and
	// ZADD, all observe the same count, and all be let through.
	pipe := rdb.TxPipeline()
	// Drop entries that have slid out of the window.
	pipe.ZRemRangeByScore(ctx, key, "0", fmt.Sprintf("%d", windowStart))
	// Count requests still in the window (before adding this one).
	countCmd := pipe.ZCard(ctx, key)
	// Record this request. The member must be unique. With the bare
	// millisecond timestamp as member, requests arriving in the same
	// millisecond (a burst, or several server instances) overwrite one another
	// and are counted once, letting bursts exceed MaxRequests.
	pipe.ZAdd(ctx, key, redis.Z{
		Score:  float64(now),
		Member: fmt.Sprintf("%d-%d", now, rand.Int63()),
	})
	// Let idle keys expire once a full window has passed without requests.
	pipe.Expire(ctx, key, time.Duration(rule.Interval)*time.Second)

	if _, err := pipe.Exec(ctx); err != nil {
		common.LogError("限流检查失败: %v", err)
		return true // fail open: never block traffic because Redis is unhealthy
	}
	return countCmd.Val() < int64(rule.MaxRequests)
}
// rateLimitWhitelist lists URL path prefixes that bypass rate limiting.
// A request is exempt when its path starts with any entry, so "/swagger/"
// covers every path beneath it.
var rateLimitWhitelist = []string{"/swagger/", "/api/auth/logout"}

// isRateLimitWhitelist reports whether path begins with one of the prefixes
// in rateLimitWhitelist.
func isRateLimitWhitelist(path string) bool {
	for i := range rateLimitWhitelist {
		if strings.HasPrefix(path, rateLimitWhitelist[i]) {
			return true
		}
	}
	return false
}

// AddRateLimitWhitelist appends path prefixes to the rate-limit whitelist.
//
// The slice is mutated without any locking, while isRateLimitWhitelist reads
// it on requests. Call this during start-up, before the server begins handling
// traffic.
func AddRateLimitWhitelist(paths ...string) {
	rateLimitWhitelist = append(rateLimitWhitelist, paths...)
}