package common

import (
	"fmt"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"

	"server/config"
)

// InitTenantPlugin builds a TenantPlugin from config.AppConfig.Tenant and
// registers it on db via db.Use. A nil db is ignored. A registration failure
// is logged rather than returned.
func InitTenantPlugin(db *gorm.DB) {
	if db == nil {
		return
	}
	plugin := NewTenantPlugin(config.AppConfig.Tenant)
	if err := db.Use(plugin); err != nil {
		LogError("租户插件初始化失败: %v", err)
		return
	}
	Info("租户插件初始化完成 (enable=%v)", plugin.enabled)
}

// TenantPlugin is a GORM multi-tenant plugin. When enabled, it appends a
// "<column> = <tenant ID>" condition to query, row, raw, update and delete
// statements whose table is listed in the configuration (or to all tables
// when the list contains "*").
type TenantPlugin struct {
	enabled  bool                // copied from TenantConfig.Enable
	column   string              // tenant column name; defaults to "tenant_id"
	tableSet map[string]struct{} // normalized names of tables to filter
	applyAll bool                // true when the configured table list contains "*"
}

// NewTenantPlugin creates a plugin from cfg. The column name is trimmed and
// defaults to "tenant_id" when blank. Table names are normalized with
// normalizeTableName; blank entries are dropped, and a "*" entry makes the
// plugin apply to every table.
func NewTenantPlugin(cfg config.TenantConfig) *TenantPlugin {
	column := strings.TrimSpace(cfg.Column)
	if column == "" {
		column = "tenant_id"
	}
	tableSet := make(map[string]struct{}, len(cfg.Tables))
	applyAll := false
	for _, table := range cfg.Tables {
		name := normalizeTableName(table)
		switch name {
		case "":
			continue
		case "*":
			applyAll = true
		default:
			tableSet[name] = struct{}{}
		}
	}
	return &TenantPlugin{
		enabled:  cfg.Enable,
		column:   column,
		tableSet: tableSet,
		applyAll: applyAll,
	}
}

// Name returns the plugin's name, "tenant_plugin".
func (p *TenantPlugin) Name() string {
	return "tenant_plugin"
}

// Initialize registers p.before ahead of GORM's built-in query, row, raw,
// update and delete callbacks. Every hook is attempted. The first registration
// error, if any, is returned wrapped with the callback name so that db.Use
// (and therefore InitTenantPlugin) can report it instead of silently
// continuing.
func (p *TenantPlugin) Initialize(db *gorm.DB) error {
	cb := db.Callback()
	registrations := []struct {
		name string
		err  error
	}{
		{"tenant:query", cb.Query().Before("gorm:query").Register("tenant:query", p.before)},
		{"tenant:row", cb.Row().Before("gorm:row").Register("tenant:row", p.before)},
		{"tenant:raw", cb.Raw().Before("gorm:raw").Register("tenant:raw", p.before)},
		{"tenant:update", cb.Update().Before("gorm:update").Register("tenant:update", p.before)},
		{"tenant:delete", cb.Delete().Before("gorm:delete").Register("tenant:delete", p.before)},
	}
	for _, r := range registrations {
		if r.err != nil {
			return fmt.Errorf("register %s callback: %w", r.name, r.err)
		}
	}
	return nil
}

// before is the callback shared by every registered hook. It adds
// "<column> = tenantID" to the statement's WHERE clause when all of these hold:
//   - the plugin is enabled;
//   - the DBSettingTenantSkip setting is not the bool true;
//   - a tenant ID can be resolved;
//   - the statement's table is configured for filtering.
//
// NOTE(review): the filter is added as a clause. Statements whose SQL text is
// already built (e.g. via Raw) presumably do not re-render clauses, so the
// filter may silently not apply to them — verify against the GORM version in
// use.
// NOTE(review): the column is unqualified, which may be ambiguous in JOINs
// where more than one table carries the tenant column.
func (p *TenantPlugin) before(db *gorm.DB) {
	if !p.enabled || db == nil || db.Statement == nil || p.skipRequested(db) {
		return
	}
	tenantID, ok := p.tenantID(db)
	if !ok {
		return
	}
	if table := p.tableName(db); table == "" || !p.matchTable(table) {
		return
	}
	db.Statement.AddClause(clause.Where{
		Exprs: []clause.Expression{
			clause.Eq{
				Column: clause.Column{Name: p.column},
				Value:  tenantID,
			},
		},
	})
}

// skipRequested reports whether the DBSettingTenantSkip setting on db holds
// the bool value true. Any other value, or no setting, means "do not skip".
func (p *TenantPlugin) skipRequested(db *gorm.DB) bool {
	raw, ok := db.Get(DBSettingTenantSkip)
	if !ok {
		return false
	}
	skip, isBool := raw.(bool)
	return isBool && skip
}

// tenantID resolves the tenant for the current statement. An explicit
// DBSettingTenantID setting takes precedence; its value is stringified and
// whitespace-trimmed. A nil value, or a nil *string, counts as unset. Otherwise
// TenantIDFromContext(db.Statement.Context) is consulted. The bool result is
// false when neither source yields a non-empty ID.
func (p *TenantPlugin) tenantID(db *gorm.DB) (string, bool) {
	if db == nil {
		return "", false
	}
	if value, ok := db.Get(DBSettingTenantID); ok {
		var id string
		switch v := value.(type) {
		case nil:
			// Treat an explicit nil as "not set": fmt.Sprint(nil) would
			// produce the bogus, non-empty tenant ID "<nil>".
		case string:
			id = v
		case *string:
			// fmt.Sprint on a *string prints the pointer address, not the text.
			if v != nil {
				id = *v
			}
		default:
			id = fmt.Sprint(v)
		}
		if id = strings.TrimSpace(id); id != "" {
			return id, true
		}
	}
	if db.Statement != nil && db.Statement.Context != nil {
		if id := TenantIDFromContext(db.Statement.Context); id != "" {
			return id, true
		}
	}
	return "", false
}

// tableName returns the normalized name of the table the statement targets.
// It prefers the parsed schema's table, then Statement.Table, and finally
// falls back to scanning already-built SQL text for a FROM target. It returns
// "" when none of these is available.
func (p *TenantPlugin) tableName(db *gorm.DB) string {
	if db == nil || db.Statement == nil {
		return ""
	}
	stmt := db.Statement
	switch {
	case stmt.Schema != nil && stmt.Schema.Table != "":
		return normalizeTableName(stmt.Schema.Table)
	case stmt.Table != "":
		return normalizeTableName(stmt.Table)
	}
	if sql := stmt.SQL.String(); sql != "" {
		return normalizeTableName(tableFromSQL(sql))
	}
	return ""
}

// matchTable reports whether table is subject to tenant filtering: always
// when "*" was configured, otherwise when its normalized name is in the
// configured table set.
func (p *TenantPlugin) matchTable(table string) bool {
	if p.applyAll {
		return true
	}
	_, ok := p.tableSet[normalizeTableName(table)]
	return ok
}

// normalizeTableName reduces a table reference to a lower-case bare name.
// It keeps only the first whitespace-separated token, which drops aliases
// such as "users u". It then keeps the segment after the last '.', which
// drops schema/database prefixes. Finally it trims the quote characters ` " [ ]
// and trailing ';' or ','.
// The dot split happens before quote trimming so that quoted qualified names
// like "public"."users" normalize to users rather than keeping a stray quote.
func normalizeTableName(name string) string {
	fields := strings.Fields(name)
	if len(fields) == 0 {
		return ""
	}
	name = fields[0]
	if idx := strings.LastIndex(name, "."); idx >= 0 {
		name = name[idx+1:]
	}
	return strings.ToLower(strings.Trim(name, "`\"[];,"))
}

// tableFromSQL returns the first token following the first FROM keyword
// (matched case-insensitively as a whole whitespace-separated token) in sql,
// with surrounding backticks and double quotes removed. It returns "" when
// there is no FROM, nothing follows it, or the FROM target is a subquery
// starting with "(". The token keeps its original case.
// Tokenizing with strings.Fields accepts any whitespace around FROM (newlines,
// tabs). It also avoids slicing the original string by an index computed on a
// lower-cased copy, whose byte length can differ for some Unicode input.
func tableFromSQL(sql string) string {
	tokens := strings.Fields(sql)
	for i, tok := range tokens {
		if !strings.EqualFold(tok, "from") {
			continue
		}
		if i+1 >= len(tokens) {
			return ""
		}
		table := strings.Trim(tokens[i+1], "`\"")
		if strings.HasPrefix(table, "(") {
			return ""
		}
		return table
	}
	return ""
}