95 lines
1.8 KiB
Go
95 lines
1.8 KiB
Go
package common
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// contextKey is an unexported type for the context keys defined by this
// package. Because no other package can construct a value of this type, keys
// of this type cannot collide with context keys set elsewhere.
type contextKey string

// ContextTenantIDKey is the context.Context key under which the tenant ID is
// stored.
const ContextTenantIDKey contextKey = "tenantId"

// Setting names used on a GORM DB handle to adjust tenant filtering.
const (
	// DBSettingTenantID temporarily overrides the tenant ID on a GORM DB.
	DBSettingTenantID = "tenant_id"
	// DBSettingTenantSkip temporarily skips tenant filtering.
	DBSettingTenantSkip = "tenant_skip"
)

// requestContextStore maps a goroutine ID (uint64, from goroutineID) to the
// context.Context bound to that goroutine by BindRequestContext.
var requestContextStore sync.Map
|
|
|
|
// BindRequestContext 绑定当前请求上下文到当前协程,返回清理函数
|
|
func BindRequestContext(ctx context.Context) func() {
|
|
gid := goroutineID()
|
|
if gid == 0 {
|
|
return func() {}
|
|
}
|
|
requestContextStore.Store(gid, ctx)
|
|
return func() {
|
|
requestContextStore.Delete(gid)
|
|
}
|
|
}
|
|
|
|
// GetRequestContext 获取当前协程绑定的请求上下文
|
|
func GetRequestContext() context.Context {
|
|
gid := goroutineID()
|
|
if gid == 0 {
|
|
return nil
|
|
}
|
|
if v, ok := requestContextStore.Load(gid); ok {
|
|
if ctx, ok := v.(context.Context); ok {
|
|
return ctx
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// TenantIDFromContext 从 Context 中获取租户ID
|
|
func TenantIDFromContext(ctx context.Context) string {
|
|
if ctx == nil {
|
|
return ""
|
|
}
|
|
value := ctx.Value(ContextTenantIDKey)
|
|
if value == nil {
|
|
return ""
|
|
}
|
|
return strings.TrimSpace(toString(value))
|
|
}
|
|
|
|
// CurrentTenantID 获取当前协程绑定的租户ID
|
|
func CurrentTenantID() string {
|
|
return TenantIDFromContext(GetRequestContext())
|
|
}
|
|
|
|
// goroutineID returns the runtime ID of the calling goroutine, or 0 if it
// cannot be determined.
//
// Go does not expose goroutine IDs through a public API, so this reads the
// header line that runtime.Stack writes for the current goroutine, of the form
// "goroutine 123 [running]:". The header prefix is checked explicitly, and any
// unexpected layout yields 0 rather than a misparsed ID. Parsing works directly
// on the stack buffer, so the call does not copy the whole buffer or build a
// token slice. That matters because the function runs on every bind/lookup of
// the request context.
func goroutineID() uint64 {
	const prefix = "goroutine "

	var buf [64]byte
	n := runtime.Stack(buf[:], false)
	if n <= len(prefix) || string(buf[:len(prefix)]) != prefix {
		return 0
	}

	start := len(prefix)
	end := start
	for end < n && buf[end] >= '0' && buf[end] <= '9' {
		end++
	}
	// Require at least one digit, terminated by the space that precedes the
	// goroutine state (or by the end of the captured bytes).
	if end == start || (end < n && buf[end] != ' ') {
		return 0
	}

	// ParseUint also rejects values that would overflow uint64.
	id, err := strconv.ParseUint(string(buf[start:end]), 10, 64)
	if err != nil {
		return 0
	}
	return id
}
|
|
|
|
// toString renders value as a string. Strings are returned unchanged, byte
// slices are converted directly, and anything else (including nil) is
// formatted with fmt.Sprint.
func toString(value interface{}) string {
	if s, isString := value.(string); isString {
		return s
	}
	if b, isBytes := value.([]byte); isBytes {
		return string(b)
	}
	return fmt.Sprint(value)
}
|