package common import ( "context" "fmt" "runtime" "strconv" "strings" "sync" ) type contextKey string const ( // ContextTenantIDKey 用于在 Context 中存储租户ID ContextTenantIDKey contextKey = "tenantId" // DBSettingTenantID 用于在 GORM DB 上临时覆盖租户ID DBSettingTenantID = "tenant_id" // DBSettingTenantSkip 用于临时跳过租户过滤 DBSettingTenantSkip = "tenant_skip" ) var requestContextStore sync.Map // BindRequestContext 绑定当前请求上下文到当前协程,返回清理函数 func BindRequestContext(ctx context.Context) func() { gid := goroutineID() if gid == 0 { return func() {} } requestContextStore.Store(gid, ctx) return func() { requestContextStore.Delete(gid) } } // GetRequestContext 获取当前协程绑定的请求上下文 func GetRequestContext() context.Context { gid := goroutineID() if gid == 0 { return nil } if v, ok := requestContextStore.Load(gid); ok { if ctx, ok := v.(context.Context); ok { return ctx } } return nil } // TenantIDFromContext 从 Context 中获取租户ID func TenantIDFromContext(ctx context.Context) string { if ctx == nil { return "" } value := ctx.Value(ContextTenantIDKey) if value == nil { return "" } return strings.TrimSpace(toString(value)) } // CurrentTenantID 获取当前协程绑定的租户ID func CurrentTenantID() string { return TenantIDFromContext(GetRequestContext()) } func goroutineID() uint64 { var buf [64]byte n := runtime.Stack(buf[:], false) if n <= 0 { return 0 } fields := strings.Fields(string(buf[:n])) if len(fields) < 2 { return 0 } id, err := strconv.ParseUint(fields[1], 10, 64) if err != nil { return 0 } return id } func toString(value interface{}) string { switch v := value.(type) { case string: return v case []byte: return string(v) default: return fmt.Sprint(value) } }