214 lines
4.6 KiB
Go
214 lines
4.6 KiB
Go
// Package middleware 参数加解密中间件
|
||
package middleware
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"errors"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
|
||
"server/common"
|
||
"server/config"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// PayloadCryptoMiddleware 请求/响应参数加解密中间件
|
||
func PayloadCryptoMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
cfg := config.AppConfig.PayloadCrypto
|
||
if !cfg.Enable {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
path := c.Request.URL.Path
|
||
if isPayloadCryptoWhitelist(path, cfg.Whitelist) {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
if strings.TrimSpace(cfg.SecretKey) == "" {
|
||
common.Warn("参数加密开启但secret_key为空,已跳过 Path=%s", path)
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
if cfg.Request.Enable {
|
||
if err := maybeDecryptRequest(c, cfg); err != nil {
|
||
common.Warn("请求解密失败: %v Path=%s", err, path)
|
||
common.Error(c, http.StatusBadRequest, "请求解密失败")
|
||
c.Abort()
|
||
return
|
||
}
|
||
}
|
||
|
||
if !cfg.Response.Enable {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
originWriter := c.Writer
|
||
writer := newCryptoResponseWriter(originWriter)
|
||
c.Writer = writer
|
||
|
||
c.Next()
|
||
|
||
c.Writer = originWriter
|
||
if err := writeEncryptedResponse(c, writer, cfg); err != nil {
|
||
common.Warn("响应加密失败: %v Path=%s", err, path)
|
||
writer.writePlain(originWriter)
|
||
}
|
||
}
|
||
}
|
||
|
||
func maybeDecryptRequest(c *gin.Context, cfg config.PayloadCryptoConfig) error {
|
||
headerVal := strings.TrimSpace(c.GetHeader(cfg.HeaderKey))
|
||
needDecrypt := headerVal != "" && headerVal != "0" && strings.ToLower(headerVal) != "false"
|
||
if cfg.Request.Required && !needDecrypt {
|
||
return errRequiredEncrypted
|
||
}
|
||
|
||
if !needDecrypt {
|
||
return nil
|
||
}
|
||
|
||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if len(bodyBytes) == 0 {
|
||
return errEmptyBody
|
||
}
|
||
|
||
var payload common.EncryptedPayload
|
||
if err := json.Unmarshal(bodyBytes, &payload); err != nil {
|
||
return err
|
||
}
|
||
|
||
plaintext, err := common.DecryptPayload(payload, cfg.SecretKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
c.Request.Body = io.NopCloser(bytes.NewReader(plaintext))
|
||
c.Request.ContentLength = int64(len(plaintext))
|
||
if c.Request.Header.Get("Content-Type") == "" {
|
||
c.Request.Header.Set("Content-Type", "application/json")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func writeEncryptedResponse(c *gin.Context, writer *cryptoResponseWriter, cfg config.PayloadCryptoConfig) error {
|
||
status := writer.Status()
|
||
if status == http.StatusNoContent || status == http.StatusNotModified {
|
||
writer.writePlain(c.Writer)
|
||
return nil
|
||
}
|
||
|
||
contentType := c.Writer.Header().Get("Content-Type")
|
||
if contentType != "" && !strings.HasPrefix(contentType, "application/json") {
|
||
writer.writePlain(c.Writer)
|
||
return nil
|
||
}
|
||
|
||
if writer.body.Len() == 0 {
|
||
writer.writePlain(c.Writer)
|
||
return nil
|
||
}
|
||
|
||
payload, err := common.EncryptPayload(writer.body.Bytes(), cfg.SecretKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
respBytes, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
c.Writer.Header().Set(cfg.HeaderKey, "1")
|
||
c.Writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||
c.Writer.Header().Del("Content-Length")
|
||
c.Writer.WriteHeader(status)
|
||
_, err = c.Writer.Write(respBytes)
|
||
return err
|
||
}
|
||
|
||
// isPayloadCryptoWhitelist reports whether path starts with any entry of
// whitelist, i.e. whether payload encryption is bypassed for this path.
//
// Matching is a plain prefix comparison, so "/api/open" also matches
// "/api/openx". Empty entries are ignored: an empty prefix would match every
// path and silently disable encryption for the whole API (e.g. a stray blank
// item in a config list).
func isPayloadCryptoWhitelist(path string, whitelist []string) bool {
	for _, prefix := range whitelist {
		if prefix != "" && strings.HasPrefix(path, prefix) {
			return true
		}
	}
	return false
}
|
||
|
||
type cryptoResponseWriter struct {
|
||
gin.ResponseWriter
|
||
body *bytes.Buffer
|
||
status int
|
||
}
|
||
|
||
func newCryptoResponseWriter(w gin.ResponseWriter) *cryptoResponseWriter {
|
||
return &cryptoResponseWriter{
|
||
ResponseWriter: w,
|
||
body: &bytes.Buffer{},
|
||
status: http.StatusOK,
|
||
}
|
||
}
|
||
|
||
func (w *cryptoResponseWriter) WriteHeader(code int) {
|
||
if code > 0 {
|
||
w.status = code
|
||
}
|
||
}
|
||
|
||
func (w *cryptoResponseWriter) WriteHeaderNow() {
|
||
if w.status == 0 {
|
||
w.status = http.StatusOK
|
||
}
|
||
}
|
||
|
||
func (w *cryptoResponseWriter) Write(data []byte) (int, error) {
|
||
return w.body.Write(data)
|
||
}
|
||
|
||
func (w *cryptoResponseWriter) WriteString(s string) (int, error) {
|
||
return w.body.WriteString(s)
|
||
}
|
||
|
||
func (w *cryptoResponseWriter) Status() int {
|
||
if w.status == 0 {
|
||
return http.StatusOK
|
||
}
|
||
return w.status
|
||
}
|
||
|
||
func (w *cryptoResponseWriter) Size() int {
|
||
return w.body.Len()
|
||
}
|
||
|
||
func (w *cryptoResponseWriter) Written() bool {
|
||
return w.body.Len() > 0
|
||
}
|
||
|
||
func (w *cryptoResponseWriter) Flush() {}
|
||
|
||
func (w *cryptoResponseWriter) writePlain(origin gin.ResponseWriter) {
|
||
origin.Header().Del("Content-Length")
|
||
origin.WriteHeader(w.Status())
|
||
if w.body.Len() == 0 {
|
||
return
|
||
}
|
||
_, _ = origin.Write(w.body.Bytes())
|
||
}
|
||
|
||
// errRequiredEncrypted is returned by maybeDecryptRequest when
// cfg.Request.Required is set but the request lacks the encryption marker header.
var errRequiredEncrypted = errors.New("request body must be encrypted")

// errEmptyBody is returned by maybeDecryptRequest when the request is marked as
// encrypted but its body is empty.
var errEmptyBody = errors.New("request body is empty")
|