wz-golang-server/server/middleware/payload_crypto.go

214 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package middleware 参数加解密中间件
package middleware
import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"server/common"
	"server/config"
)
// PayloadCryptoMiddleware returns a gin middleware that decrypts request
// bodies and encrypts response bodies according to
// config.AppConfig.PayloadCrypto.
//
// The middleware only calls c.Next and does nothing else when:
//   - payload crypto is disabled (cfg.Enable is false);
//   - the request path starts with a cfg.Whitelist prefix;
//   - cfg.SecretKey is blank (a warning is logged).
//
// Request side: when cfg.Request.Enable is set, maybeDecryptRequest runs before
// the handlers. If it fails, the request is answered with 400 and aborted.
//
// Response side: when cfg.Response.Enable is set, downstream handlers write into
// an in-memory cryptoResponseWriter. Afterwards the buffered output is passed to
// writeEncryptedResponse. If that fails before anything reached the client, the
// buffered plaintext is sent instead; this fallback is logged as a warning.
func PayloadCryptoMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		cfg := config.AppConfig.PayloadCrypto
		if !cfg.Enable {
			c.Next()
			return
		}
		path := c.Request.URL.Path
		if isPayloadCryptoWhitelist(path, cfg.Whitelist) {
			c.Next()
			return
		}
		if strings.TrimSpace(cfg.SecretKey) == "" {
			common.Warn("参数加密开启但secret_key为空已跳过 Path=%s", path)
			c.Next()
			return
		}
		if cfg.Request.Enable {
			if err := maybeDecryptRequest(c, cfg); err != nil {
				common.Warn("请求解密失败: %v Path=%s", err, path)
				common.Error(c, http.StatusBadRequest, "请求解密失败")
				c.Abort()
				return
			}
		}
		if !cfg.Response.Enable {
			c.Next()
			return
		}

		originWriter := c.Writer
		writer := newCryptoResponseWriter(originWriter)
		c.Writer = writer
		// Put the real writer back even if a downstream handler panics.
		// Otherwise an outer recovery middleware (e.g. gin.Recovery) would write
		// its error status into the buffer, which is then dropped. The client
		// would see a default 200 with an empty body instead of the 500.
		defer func() { c.Writer = originWriter }()
		c.Next()
		c.Writer = originWriter

		if err := writeEncryptedResponse(c, writer, cfg); err != nil {
			common.Warn("响应加密失败: %v Path=%s", err, path)
			// Only fall back to plaintext while nothing has been sent. Once
			// writeEncryptedResponse has started writing (e.g. the final Write
			// failed), appending the plaintext would corrupt the response.
			if !originWriter.Written() {
				writer.writePlain(originWriter)
			}
		}
	}
}
// maybeDecryptRequest replaces c.Request.Body with its decrypted plaintext when
// the request is flagged as encrypted.
//
// Flagging: a request counts as encrypted when the cfg.HeaderKey header is
// present and its trimmed value is neither "0" nor "false" (case-insensitive).
// Unflagged requests are left untouched, unless cfg.Request.Required is set;
// then errRequiredEncrypted is returned.
//
// Decryption: for flagged requests the body is read fully, decoded as a JSON
// common.EncryptedPayload, and decrypted with cfg.SecretKey. On success the
// request body and ContentLength are replaced with the plaintext. If the client
// sent no Content-Type, it defaults to application/json.
//
// It returns errRequiredEncrypted, errEmptyBody (missing or empty body), or a
// wrapped read / JSON-decode / decryption error. Once the body has been read,
// it is consumed even when an error is returned.
func maybeDecryptRequest(c *gin.Context, cfg config.PayloadCryptoConfig) error {
	headerVal := strings.TrimSpace(c.GetHeader(cfg.HeaderKey))
	needDecrypt := headerVal != "" && headerVal != "0" && !strings.EqualFold(headerVal, "false")
	if !needDecrypt {
		if cfg.Request.Required {
			return errRequiredEncrypted
		}
		return nil
	}
	// http.NewRequest leaves Body nil when constructed without a body; treat
	// that as an empty body instead of panicking inside io.ReadAll.
	if c.Request.Body == nil {
		return errEmptyBody
	}
	// NOTE(review): the read is unbounded; an upstream size limit
	// (e.g. http.MaxBytesReader) is presumably required to cap memory use.
	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		return fmt.Errorf("reading request body: %w", err)
	}
	if len(bodyBytes) == 0 {
		return errEmptyBody
	}
	var payload common.EncryptedPayload
	if err := json.Unmarshal(bodyBytes, &payload); err != nil {
		return fmt.Errorf("decoding encrypted payload: %w", err)
	}
	plaintext, err := common.DecryptPayload(payload, cfg.SecretKey)
	if err != nil {
		return fmt.Errorf("decrypting payload: %w", err)
	}
	c.Request.Body = io.NopCloser(bytes.NewReader(plaintext))
	c.Request.ContentLength = int64(len(plaintext))
	if c.Request.Header.Get("Content-Type") == "" {
		c.Request.Header.Set("Content-Type", "application/json")
	}
	return nil
}
// writeEncryptedResponse sends the handler output buffered in writer through
// c.Writer. The caller restores c.Writer to the real, unbuffered writer before
// calling this.
//
// The buffered response is sent unmodified via writePlain when any of these
// holds:
//   - the status is 204 No Content or 304 Not Modified;
//   - a Content-Type is set and does not start with "application/json";
//   - the handler wrote no body.
//
// Otherwise the body is encrypted with common.EncryptPayload and serialized as
// JSON. It is then sent with the original status, the cfg.HeaderKey header set
// to "1", and a JSON Content-Type. Any Content-Length set by the handler is
// removed because the encrypted body has a different length.
//
// An error is returned if encryption, serialization, or the final write fails.
// In the first two cases nothing has been written to the client yet.
func writeEncryptedResponse(c *gin.Context, writer *cryptoResponseWriter, cfg config.PayloadCryptoConfig) error {
	code := writer.Status()

	switch ct := c.Writer.Header().Get("Content-Type"); {
	case code == http.StatusNoContent || code == http.StatusNotModified,
		ct != "" && !strings.HasPrefix(ct, "application/json"),
		writer.body.Len() == 0:
		writer.writePlain(c.Writer)
		return nil
	}

	sealed, encErr := common.EncryptPayload(writer.body.Bytes(), cfg.SecretKey)
	if encErr != nil {
		return encErr
	}
	out, marshalErr := json.Marshal(sealed)
	if marshalErr != nil {
		return marshalErr
	}

	hdr := c.Writer.Header()
	hdr.Set(cfg.HeaderKey, "1")
	hdr.Set("Content-Type", "application/json; charset=utf-8")
	hdr.Del("Content-Length")
	c.Writer.WriteHeader(code)
	if _, writeErr := c.Writer.Write(out); writeErr != nil {
		return writeErr
	}
	return nil
}
// isPayloadCryptoWhitelist reports whether path is exempt from payload
// encryption because it starts with one of the whitelist prefixes.
//
// Matching is a plain string-prefix test, so "/api/upload" also matches
// "/api/upload-file". Empty entries are ignored. An empty string is a prefix of
// every path, so a blank entry left in the configuration would otherwise
// silently disable encryption for the whole service.
func isPayloadCryptoWhitelist(path string, whitelist []string) bool {
	for _, prefix := range whitelist {
		if prefix != "" && strings.HasPrefix(path, prefix) {
			return true
		}
	}
	return false
}
// cryptoResponseWriter is a gin.ResponseWriter that holds the response body
// and status code in memory instead of sending them. This lets
// PayloadCryptoMiddleware encrypt the complete body afterwards.
//
// Header() and every method not overridden below are delegated to the embedded
// writer, so headers set by handlers go straight onto the real response.
type cryptoResponseWriter struct {
	gin.ResponseWriter
	body   *bytes.Buffer // buffered response body
	status int           // buffered status code
}

// newCryptoResponseWriter wraps w in a buffering cryptoResponseWriter.
//
// The buffered status starts from w's current status rather than a fixed 200.
// gin presets a status before running the handler chain in some cases (e.g.
// 404/405 for unmatched routes). Forcing 200 here would make writePlain
// overwrite that status when the chain itself writes nothing.
func newCryptoResponseWriter(w gin.ResponseWriter) *cryptoResponseWriter {
	status := w.Status()
	if status <= 0 {
		status = http.StatusOK
	}
	return &cryptoResponseWriter{
		ResponseWriter: w,
		body:           &bytes.Buffer{},
		status:         status,
	}
}

// WriteHeader records code as the status to send later. Non-positive codes
// are ignored. Nothing is written to the client.
func (w *cryptoResponseWriter) WriteHeader(code int) {
	if code > 0 {
		w.status = code
	}
}

// WriteHeaderNow does not send anything; headers go out only when the
// middleware flushes the buffered response. It only defaults an unset status
// to 200.
func (w *cryptoResponseWriter) WriteHeaderNow() {
	if w.status == 0 {
		w.status = http.StatusOK
	}
}

// Write appends data to the buffered body.
func (w *cryptoResponseWriter) Write(data []byte) (int, error) {
	return w.body.Write(data)
}

// WriteString appends s to the buffered body.
func (w *cryptoResponseWriter) WriteString(s string) (int, error) {
	return w.body.WriteString(s)
}

// Status returns the buffered status code, or 200 if none is set.
func (w *cryptoResponseWriter) Status() int {
	if w.status == 0 {
		return http.StatusOK
	}
	return w.status
}

// Size returns the number of buffered body bytes.
// NOTE(review): gin's own writer reports -1 before anything is written; this
// one reports 0.
func (w *cryptoResponseWriter) Size() int {
	return w.body.Len()
}

// Written reports whether any body bytes have been buffered. A status set via
// WriteHeader alone does not count.
func (w *cryptoResponseWriter) Written() bool {
	return w.body.Len() > 0
}

// Flush is deliberately a no-op: the body is held until the middleware
// processes it, so streamed output is delivered in one piece at the end.
func (w *cryptoResponseWriter) Flush() {}

// writePlain sends the buffered status and body to origin unmodified. Any
// Content-Length already on origin is removed first, presumably so a stale
// value cannot disagree with the bytes actually written.
func (w *cryptoResponseWriter) writePlain(origin gin.ResponseWriter) {
	origin.Header().Del("Content-Length")
	origin.WriteHeader(w.Status())
	if w.body.Len() == 0 {
		return
	}
	// Best effort: there is no caller to report a write failure to.
	_, _ = origin.Write(w.body.Bytes())
}
// errRequiredEncrypted is returned when encrypted request bodies are required
// by configuration but the request is not flagged as encrypted.
var errRequiredEncrypted = errors.New("request body must be encrypted")

// errEmptyBody is returned when a request is flagged as encrypted but carries
// no body to decrypt.
var errEmptyBody = errors.New("request body is empty")