wz-golang-server/server/common/tenant_plugin.go

185 lines
3.9 KiB
Go

package common
import (
"fmt"
"strings"
"server/config"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// InitTenantPlugin initializes the tenant plugin: it creates a TenantPlugin
// from config.AppConfig.Tenant and installs it on db through db.Use.
//
// A nil db is a no-op. An installation error is logged with LogError and not
// returned; on success the plugin's enabled flag is logged with Info.
//
// NOTE(review): because the error is only logged, the application keeps
// running without tenant filtering if installation fails — confirm that is
// acceptable.
func InitTenantPlugin(db *gorm.DB) {
	if db == nil {
		return
	}
	tenantPlugin := NewTenantPlugin(config.AppConfig.Tenant)
	err := db.Use(tenantPlugin)
	if err != nil {
		LogError("租户插件初始化失败: %v", err)
		return
	}
	Info("租户插件初始化完成 (enable=%v)", tenantPlugin.enabled)
}
// TenantPlugin is a GORM multi-tenant plugin. It adds a tenant equality
// condition to statements that touch configured tables. Build instances with
// NewTenantPlugin.
type TenantPlugin struct {
	// column is the tenant column name used in the added condition.
	column   string
	// tableSet holds the normalized names of the tables to filter.
	tableSet map[string]struct{}
	// enabled turns the plugin's filtering on or off.
	enabled  bool
	// applyAll is set when the configured table list contained "*"; every
	// table is then filtered regardless of tableSet.
	applyAll bool
}
// NewTenantPlugin creates a plugin instance from cfg.
//
// The filter column is cfg.Column with surrounding whitespace removed, or
// "tenant_id" when that is blank. Every entry of cfg.Tables is normalized
// with normalizeTableName. Blank results are dropped, and a "*" entry enables
// filtering for all tables instead of being stored as a name.
func NewTenantPlugin(cfg config.TenantConfig) *TenantPlugin {
	plugin := &TenantPlugin{
		enabled:  cfg.Enable,
		column:   strings.TrimSpace(cfg.Column),
		tableSet: make(map[string]struct{}),
	}
	if plugin.column == "" {
		plugin.column = "tenant_id"
	}
	for _, entry := range cfg.Tables {
		switch name := normalizeTableName(entry); name {
		case "":
			// Ignore blank entries.
		case "*":
			plugin.applyAll = true
		default:
			plugin.tableSet[name] = struct{}{}
		}
	}
	return plugin
}
// Name implements gorm.Plugin and returns the plugin's identifier.
func (p *TenantPlugin) Name() string {
	const pluginName = "tenant_plugin"
	return pluginName
}
// Initialize implements gorm.Plugin by registering callbacks. It installs
// p.before ahead of gorm's built-in query, row, raw, update and delete
// callbacks.
//
// The first registration error is returned, wrapped with the callback name,
// instead of being discarded. DB.Use passes it back to its caller, so a
// half-installed tenant filter is reported rather than silently ignored.
// Callbacks registered before the failing one stay registered.
func (p *TenantPlugin) Initialize(db *gorm.DB) error {
	cb := db.Callback()
	if err := cb.Query().Before("gorm:query").Register("tenant:query", p.before); err != nil {
		return fmt.Errorf("register tenant:query callback: %w", err)
	}
	if err := cb.Row().Before("gorm:row").Register("tenant:row", p.before); err != nil {
		return fmt.Errorf("register tenant:row callback: %w", err)
	}
	if err := cb.Raw().Before("gorm:raw").Register("tenant:raw", p.before); err != nil {
		return fmt.Errorf("register tenant:raw callback: %w", err)
	}
	if err := cb.Update().Before("gorm:update").Register("tenant:update", p.before); err != nil {
		return fmt.Errorf("register tenant:update callback: %w", err)
	}
	if err := cb.Delete().Before("gorm:delete").Register("tenant:delete", p.before); err != nil {
		return fmt.Errorf("register tenant:delete callback: %w", err)
	}
	return nil
}
// before is the callback shared by every hook registered in Initialize.
//
// It adds `<column> = <tenantID>` to the statement's WHERE clause when all of
// these hold:
//   - the plugin is enabled;
//   - db is not marked with DBSettingTenantSkip set to true;
//   - a tenant ID can be resolved (see tenantID);
//   - the statement's table is configured for filtering (see tableName and
//     matchTable).
//
// The column is qualified with clause.CurrentTable when Statement.Table is
// set and the configured column has no explicit qualifier. gorm substitutes
// the statement's table or alias, so queries joining two tenant-scoped tables
// do not fail with an "ambiguous column" error.
//
// NOTE(review): for statements that already carry raw SQL (db.Raw / db.Exec),
// gorm appears to execute Statement.SQL as-is without rendering added
// clauses, so those statements are likely NOT tenant-filtered — confirm.
// NOTE(review): the added WHERE clause likely also satisfies gorm's
// missing-WHERE guard for Update/Delete (ErrMissingWhereClause), allowing
// tenant-wide updates/deletes without other conditions — confirm intended.
func (p *TenantPlugin) before(db *gorm.DB) {
	if !p.enabled || db == nil || db.Statement == nil {
		return
	}
	if skip, ok := db.Get(DBSettingTenantSkip); ok {
		if v, ok := skip.(bool); ok && v {
			return
		}
	}
	tenantID, ok := p.tenantID(db)
	if !ok {
		return
	}
	table := p.tableName(db)
	if table == "" || !p.matchTable(table) {
		return
	}
	column := clause.Column{Name: p.column}
	// Only qualify when gorm has a table name to substitute; an empty
	// Statement.Table would render an empty quoted identifier.
	if db.Statement.Table != "" && !strings.Contains(p.column, ".") {
		column.Table = clause.CurrentTable
	}
	db.Statement.AddClause(clause.Where{
		Exprs: []clause.Expression{
			clause.Eq{Column: column, Value: tenantID},
		},
	})
}
// tenantID resolves the tenant ID for the current statement.
//
// An explicit DBSettingTenantID value on db takes precedence. It is
// stringified with fmt.Sprint and trimmed. If that is absent or blank, the ID
// comes from the statement's context via TenantIDFromContext. The boolean
// result is false when neither source yields a non-empty ID.
//
// NOTE(review): a nil setting value stringifies to "<nil>" and is then used
// as the tenant ID rather than being ignored — confirm that is intended.
func (p *TenantPlugin) tenantID(db *gorm.DB) (string, bool) {
	if db == nil {
		return "", false
	}
	var id string
	if raw, found := db.Get(DBSettingTenantID); found {
		id = strings.TrimSpace(fmt.Sprint(raw))
	}
	if id == "" && db.Statement != nil && db.Statement.Context != nil {
		id = TenantIDFromContext(db.Statement.Context)
	}
	return id, id != ""
}
// tableName returns the normalized name of the table the statement targets,
// or "" if it cannot be determined.
//
// Sources are tried in order: the parsed schema's table, then
// Statement.Table, then a best-effort parse of Statement.SQL with
// tableFromSQL.
func (p *TenantPlugin) tableName(db *gorm.DB) string {
	if db == nil || db.Statement == nil {
		return ""
	}
	stmt := db.Statement
	switch {
	case stmt.Schema != nil && stmt.Schema.Table != "":
		return normalizeTableName(stmt.Schema.Table)
	case stmt.Table != "":
		return normalizeTableName(stmt.Table)
	}
	if rawSQL := stmt.SQL.String(); rawSQL != "" {
		return normalizeTableName(tableFromSQL(rawSQL))
	}
	return ""
}
// matchTable reports whether table should be tenant-filtered. It returns
// true for every table when "*" was configured. Otherwise it returns true
// only when the normalized name is in the configured table set.
func (p *TenantPlugin) matchTable(table string) bool {
	_, listed := p.tableSet[normalizeTableName(table)]
	return p.applyAll || listed
}
// normalizeTableName reduces a table reference to a bare, lower-cased name
// so configured names and statement tables can be compared.
//
// Processing steps:
//   - Only the first whitespace-separated token is kept, which drops an
//     alias such as "users u".
//   - Only the part after the last "." is kept, which drops a
//     schema/database qualifier.
//   - Surrounding backticks or double quotes are removed.
//
// Quotes are stripped after the qualifier is removed, so quoted qualified
// names like "`db`.`users`" or "\"public\".\"users\"" normalize to "users"
// rather than keeping a stray leading quote. Blank input yields "".
func normalizeTableName(name string) string {
	fields := strings.Fields(name)
	if len(fields) == 0 {
		return ""
	}
	name = fields[0]
	if idx := strings.LastIndex(name, "."); idx >= 0 {
		name = name[idx+1:]
	}
	name = strings.Trim(name, "`\"")
	return strings.ToLower(name)
}
// tableFromSQL is a best-effort heuristic that returns the table name after
// the first FROM keyword in sql.
//
// It returns "" in three cases: sql has no FROM, nothing follows FROM, or
// FROM is followed by a parenthesised sub-query.
//
// sql is split into tokens on any whitespace, and FROM is matched
// case-insensitively with strings.EqualFold. This recognizes forms like
// "FROM\n`users`". It also avoids reusing byte offsets from a lower-cased
// copy: strings.ToLower can change the byte length of non-ASCII text, so such
// offsets are not valid in the original string. A trailing ';' or ',' and
// surrounding backticks/double quotes are stripped from the result. Case is
// preserved; callers normalize it.
func tableFromSQL(sql string) string {
	fields := strings.Fields(sql)
	for i, field := range fields {
		if !strings.EqualFold(field, "from") {
			continue
		}
		if i+1 >= len(fields) {
			return ""
		}
		table := strings.TrimRight(fields[i+1], ";,")
		table = strings.Trim(table, "`\"")
		if strings.HasPrefix(table, "(") {
			return ""
		}
		return table
	}
	return ""
}