// Source: wz-golang-server/server/common/payload_crypto.go (107 lines, 2.3 KiB, Go)

// Package common provides helpers for encrypting response payloads and decrypting request payloads.
package common
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"strings"
)
// EncryptedPayload is the JSON wire form of an encrypted payload, as produced
// by EncryptPayload and consumed by DecryptPayload. Both fields carry text
// encoded with base64.StdEncoding.
type EncryptedPayload struct {
	// Nonce is the base64-encoded AES-GCM nonce used to seal Ciphertext.
	Nonce string `json:"nonce"`
	// Ciphertext is the base64-encoded AES-GCM sealed output (encrypted data
	// followed by the authentication tag).
	Ciphertext string `json:"ciphertext"`
}
// EncryptPayload encrypts plaintext (a response body) with AES-GCM under the
// key that deriveAESKey derives from secret.
//
// Every call draws a fresh random nonce from crypto/rand, so sealing the same
// plaintext twice produces different output. No additional authenticated data
// is bound to the message. The nonce and the sealed bytes (ciphertext plus GCM
// tag) are returned encoded with base64.StdEncoding. On any error the zero
// EncryptedPayload is returned together with the error.
func EncryptPayload(plaintext []byte, secret string) (EncryptedPayload, error) {
	var out EncryptedPayload

	key, err := deriveAESKey(secret)
	if err != nil {
		return out, err
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return out, err
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return out, err
	}

	// A GCM nonce must never repeat under the same key; drawing a random one
	// per message is how this function avoids reuse.
	nonce := make([]byte, aead.NonceSize())
	if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
		return out, err
	}
	sealed := aead.Seal(nil, nonce, plaintext, nil)

	out.Nonce = base64.StdEncoding.EncodeToString(nonce)
	out.Ciphertext = base64.StdEncoding.EncodeToString(sealed)
	return out, nil
}
// DecryptPayload reverses EncryptPayload for a request payload. It decodes
// payload.Nonce and payload.Ciphertext with base64.StdEncoding and opens the
// ciphertext with AES-GCM under the key that deriveAESKey derives from secret.
// No additional authenticated data is used.
//
// Because the payload is request input and may be malformed, each
// input-dependent failure names the step that failed ("decode nonce",
// "decode ciphertext", "invalid nonce size", "open ciphertext"). The
// underlying error stays reachable through errors.Is / errors.As. A payload
// that was altered, or sealed under a different key, fails at "open
// ciphertext".
func DecryptPayload(payload EncryptedPayload, secret string) ([]byte, error) {
	key, err := deriveAESKey(secret)
	if err != nil {
		return nil, err
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	nonce, err := base64.StdEncoding.DecodeString(payload.Nonce)
	if err != nil {
		return nil, fmt.Errorf("decode nonce: %w", err)
	}
	// gcm.Open panics when given a nonce of the wrong length, so untrusted
	// input must be rejected here instead.
	if len(nonce) != gcm.NonceSize() {
		return nil, fmt.Errorf("invalid nonce size: got %d bytes, want %d", len(nonce), gcm.NonceSize())
	}

	ciphertext, err := base64.StdEncoding.DecodeString(payload.Ciphertext)
	if err != nil {
		return nil, fmt.Errorf("decode ciphertext: %w", err)
	}

	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return nil, fmt.Errorf("open ciphertext: %w", err)
	}
	return plaintext, nil
}
// deriveAESKey turns a configured secret into an AES key. Leading and trailing
// whitespace is ignored, and an empty result is an error. After that the first
// matching rule wins:
//
//  1. If the secret is valid base64.StdEncoding text that decodes to 16, 24 or
//     32 bytes, the decoded bytes are the key (AES-128/192/256).
//  2. Otherwise, if the secret itself is 16, 24 or 32 bytes long, its raw
//     bytes are the key.
//  3. Otherwise the key is the SHA-256 digest of the secret (AES-256).
//
// Rule 1 is tried first, so a plain-text secret that happens to be valid
// base64 of a suitable length is used in decoded form, not verbatim. For
// example, any 32-character alphanumeric string decodes to 24 bytes. Every
// party sharing the secret must apply these exact rules to get the same key.
// NOTE(review): rule 2 uses a short passphrase directly as key material;
// consider configuring a random base64-encoded key instead.
func deriveAESKey(secret string) ([]byte, error) {
	trimmed := strings.TrimSpace(secret)
	if len(trimmed) == 0 {
		return nil, errors.New("secret is empty")
	}

	isKeySize := func(n int) bool {
		switch n {
		case 16, 24, 32:
			return true
		default:
			return false
		}
	}

	// Rule 1: base64-encoded key material.
	if raw, decodeErr := base64.StdEncoding.DecodeString(trimmed); decodeErr == nil && isKeySize(len(raw)) {
		return raw, nil
	}
	// Rule 2: the secret's own bytes already form a valid-size key.
	if isKeySize(len(trimmed)) {
		return []byte(trimmed), nil
	}
	// Rule 3: hash anything else down to a 32-byte key.
	digest := sha256.Sum256([]byte(trimmed))
	return digest[:], nil
}