// Package middleware 参数加解密中间件 package middleware import ( "bytes" "encoding/json" "errors" "io" "net/http" "strings" "server/common" "server/config" "github.com/gin-gonic/gin" ) // PayloadCryptoMiddleware 请求/响应参数加解密中间件 func PayloadCryptoMiddleware() gin.HandlerFunc { return func(c *gin.Context) { cfg := config.AppConfig.PayloadCrypto if !cfg.Enable { c.Next() return } path := c.Request.URL.Path if isPayloadCryptoWhitelist(path, cfg.Whitelist) { c.Next() return } if strings.TrimSpace(cfg.SecretKey) == "" { common.Warn("参数加密开启但secret_key为空,已跳过 Path=%s", path) c.Next() return } if cfg.Request.Enable { if err := maybeDecryptRequest(c, cfg); err != nil { common.Warn("请求解密失败: %v Path=%s", err, path) common.Error(c, http.StatusBadRequest, "请求解密失败") c.Abort() return } } if !cfg.Response.Enable { c.Next() return } originWriter := c.Writer writer := newCryptoResponseWriter(originWriter) c.Writer = writer c.Next() c.Writer = originWriter if err := writeEncryptedResponse(c, writer, cfg); err != nil { common.Warn("响应加密失败: %v Path=%s", err, path) writer.writePlain(originWriter) } } } func maybeDecryptRequest(c *gin.Context, cfg config.PayloadCryptoConfig) error { headerVal := strings.TrimSpace(c.GetHeader(cfg.HeaderKey)) needDecrypt := headerVal != "" && headerVal != "0" && strings.ToLower(headerVal) != "false" if cfg.Request.Required && !needDecrypt { return errRequiredEncrypted } if !needDecrypt { return nil } bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { return err } if len(bodyBytes) == 0 { return errEmptyBody } var payload common.EncryptedPayload if err := json.Unmarshal(bodyBytes, &payload); err != nil { return err } plaintext, err := common.DecryptPayload(payload, cfg.SecretKey) if err != nil { return err } c.Request.Body = io.NopCloser(bytes.NewReader(plaintext)) c.Request.ContentLength = int64(len(plaintext)) if c.Request.Header.Get("Content-Type") == "" { c.Request.Header.Set("Content-Type", "application/json") } return nil } func writeEncryptedResponse(c *gin.Context, writer *cryptoResponseWriter, cfg config.PayloadCryptoConfig) error { status := writer.Status() if status == http.StatusNoContent || status == http.StatusNotModified { writer.writePlain(c.Writer) return nil } contentType := c.Writer.Header().Get("Content-Type") if contentType != "" && !strings.HasPrefix(contentType, "application/json") { writer.writePlain(c.Writer) return nil } if writer.body.Len() == 0 { writer.writePlain(c.Writer) return nil } payload, err := common.EncryptPayload(writer.body.Bytes(), cfg.SecretKey) if err != nil { return err } respBytes, err := json.Marshal(payload) if err != nil { return err } c.Writer.Header().Set(cfg.HeaderKey, "1") c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8") c.Writer.Header().Del("Content-Length") c.Writer.WriteHeader(status) _, err = c.Writer.Write(respBytes) return err } func isPayloadCryptoWhitelist(path string, whitelist []string) bool { for _, white := range whitelist { if len(path) >= len(white) && path[:len(white)] == white { return true } } return false } type cryptoResponseWriter struct { gin.ResponseWriter body *bytes.Buffer status int } func newCryptoResponseWriter(w gin.ResponseWriter) *cryptoResponseWriter { return &cryptoResponseWriter{ ResponseWriter: w, body: &bytes.Buffer{}, status: http.StatusOK, } } func (w *cryptoResponseWriter) WriteHeader(code int) { if code > 0 { w.status = code } } func (w *cryptoResponseWriter) WriteHeaderNow() { if w.status == 0 { w.status = http.StatusOK } } func (w *cryptoResponseWriter) Write(data []byte) (int, error) { return w.body.Write(data) } func (w *cryptoResponseWriter) WriteString(s string) (int, error) { return w.body.WriteString(s) } func (w *cryptoResponseWriter) Status() int { if w.status == 0 { return http.StatusOK } return w.status } func (w *cryptoResponseWriter) Size() int { return w.body.Len() } func (w *cryptoResponseWriter) Written() bool { return w.body.Len() > 0 } func (w *cryptoResponseWriter) Flush() {} func (w *cryptoResponseWriter) writePlain(origin gin.ResponseWriter) { origin.Header().Del("Content-Length") origin.WriteHeader(w.Status()) if w.body.Len() == 0 { return } _, _ = origin.Write(w.body.Bytes()) } var ( errRequiredEncrypted = errors.New("request body must be encrypted") errEmptyBody = errors.New("request body is empty") )