wz-golang-server/server/common/tenant_context.go

95 lines
1.8 KiB
Go

package common
import (
"context"
"fmt"
"runtime"
"strconv"
"strings"
"sync"
)
// contextKey is an unexported key type for values this package stores in a
// context.Context, so its keys cannot collide with plain string keys set by
// other packages.
type contextKey string

// ContextTenantIDKey is the context key under which the tenant ID is stored.
const ContextTenantIDKey contextKey = "tenantId"

const (
	// DBSettingTenantID is the GORM DB setting name used to temporarily
	// override the tenant ID.
	DBSettingTenantID = "tenant_id"
	// DBSettingTenantSkip is the GORM DB setting name used to temporarily
	// skip tenant filtering.
	DBSettingTenantSkip = "tenant_skip"
)

// requestContextStore maps a goroutine ID (uint64) to the context.Context
// bound to that goroutine by BindRequestContext.
var requestContextStore sync.Map
// BindRequestContext binds ctx to the calling goroutine so that later calls to
// GetRequestContext (and CurrentTenantID) on the same goroutine can retrieve
// it. It returns a cleanup function that undoes the binding.
//
// Bindings nest. If the goroutine already had a context bound, the returned
// cleanup restores that previous context instead of removing the entry. This
// means an inner bind/cleanup pair (for example, a nested handler that rebinds
// an enriched context) does not erase the outer binding. Run cleanups in
// reverse order of binding, typically via defer:
//
//	defer BindRequestContext(ctx)()
//
// The binding is keyed by goroutine ID, so goroutines started from the bound
// goroutine do not see it. The entry stays in the store until the cleanup
// runs. If the goroutine ID cannot be determined, nothing is bound and a no-op
// cleanup is returned.
func BindRequestContext(ctx context.Context) func() {
	gid := goroutineID()
	if gid == 0 {
		return func() {}
	}
	// Remember any existing binding so nested binds unwind to it.
	prev, hadPrev := requestContextStore.Load(gid)
	requestContextStore.Store(gid, ctx)
	return func() {
		if hadPrev {
			requestContextStore.Store(gid, prev)
			return
		}
		requestContextStore.Delete(gid)
	}
}
// GetRequestContext returns the context bound to the calling goroutine by
// BindRequestContext. It returns nil when no context is bound, or when the
// goroutine ID cannot be determined.
func GetRequestContext() context.Context {
	id := goroutineID()
	if id == 0 {
		return nil
	}
	stored, found := requestContextStore.Load(id)
	if !found {
		return nil
	}
	// A failed assertion yields the zero value, i.e. a nil context.
	ctx, _ := stored.(context.Context)
	return ctx
}
// TenantIDFromContext extracts the tenant ID stored under ContextTenantIDKey
// in ctx. The stored value is converted to a string via toString, and
// surrounding whitespace is trimmed. It returns "" when ctx is nil or holds
// no value for the key.
func TenantIDFromContext(ctx context.Context) string {
	if ctx != nil {
		if raw := ctx.Value(ContextTenantIDKey); raw != nil {
			return strings.TrimSpace(toString(raw))
		}
	}
	return ""
}
// CurrentTenantID 获取当前协程绑定的租户ID
func CurrentTenantID() string {
return TenantIDFromContext(GetRequestContext())
}
// goroutineID returns the numeric ID of the calling goroutine. It parses the ID
// from the first line of runtime.Stack output, which has the form
// "goroutine <id> [<status>]:". It returns 0 if that header cannot be read or
// parsed. Go has no official goroutine-ID API, so this depends on the
// runtime's trace format.
func goroutineID() uint64 {
	// 64 bytes is enough for the header: "goroutine " plus at most 20 digits.
	var stack [64]byte
	written := runtime.Stack(stack[:], false)
	if written > 0 {
		if parts := strings.Fields(string(stack[:written])); len(parts) >= 2 {
			if id, err := strconv.ParseUint(parts[1], 10, 64); err == nil {
				return id
			}
		}
	}
	return 0
}
// toString converts value to a string. Strings are returned unchanged, byte
// slices are converted directly, and any other value (including nil) is
// formatted with fmt.Sprint.
func toString(value interface{}) string {
	if s, isString := value.(string); isString {
		return s
	}
	if raw, isBytes := value.([]byte); isBytes {
		return string(raw)
	}
	return fmt.Sprint(value)
}