185 lines
3.9 KiB
Go
185 lines
3.9 KiB
Go
package common
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"server/config"
|
|
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
// InitTenantPlugin 初始化租户插件
|
|
func InitTenantPlugin(db *gorm.DB) {
|
|
if db == nil {
|
|
return
|
|
}
|
|
plugin := NewTenantPlugin(config.AppConfig.Tenant)
|
|
if err := db.Use(plugin); err != nil {
|
|
LogError("租户插件初始化失败: %v", err)
|
|
return
|
|
}
|
|
Info("租户插件初始化完成 (enable=%v)", plugin.enabled)
|
|
}
|
|
|
|
// TenantPlugin is a GORM plugin that adds a tenant filter to statements
// touching the configured tables.
type TenantPlugin struct {
	// enabled switches the plugin on; when false the callback returns
	// without modifying the statement.
	enabled bool
	// column is the name of the column compared against the tenant ID.
	column string
	// tableSet holds the normalized names of the tables to filter.
	tableSet map[string]struct{}
	// applyAll is set when the configured table list contains "*", in
	// which case every table matches.
	applyAll bool
}
|
|
|
|
// NewTenantPlugin 创建插件实例
|
|
func NewTenantPlugin(cfg config.TenantConfig) *TenantPlugin {
|
|
column := strings.TrimSpace(cfg.Column)
|
|
if column == "" {
|
|
column = "tenant_id"
|
|
}
|
|
tableSet := make(map[string]struct{})
|
|
applyAll := false
|
|
for _, table := range cfg.Tables {
|
|
name := normalizeTableName(table)
|
|
if name == "" {
|
|
continue
|
|
}
|
|
if name == "*" {
|
|
applyAll = true
|
|
continue
|
|
}
|
|
tableSet[name] = struct{}{}
|
|
}
|
|
return &TenantPlugin{
|
|
enabled: cfg.Enable,
|
|
column: column,
|
|
tableSet: tableSet,
|
|
applyAll: applyAll,
|
|
}
|
|
}
|
|
|
|
// Name 插件名
|
|
func (p *TenantPlugin) Name() string {
|
|
return "tenant_plugin"
|
|
}
|
|
|
|
// Initialize 注册回调
|
|
func (p *TenantPlugin) Initialize(db *gorm.DB) error {
|
|
db.Callback().Query().Before("gorm:query").Register("tenant:query", p.before)
|
|
db.Callback().Row().Before("gorm:row").Register("tenant:row", p.before)
|
|
db.Callback().Raw().Before("gorm:raw").Register("tenant:raw", p.before)
|
|
db.Callback().Update().Before("gorm:update").Register("tenant:update", p.before)
|
|
db.Callback().Delete().Before("gorm:delete").Register("tenant:delete", p.before)
|
|
return nil
|
|
}
|
|
|
|
func (p *TenantPlugin) before(db *gorm.DB) {
|
|
if !p.enabled || db == nil || db.Statement == nil {
|
|
return
|
|
}
|
|
if skip, ok := db.Get(DBSettingTenantSkip); ok {
|
|
if v, ok := skip.(bool); ok && v {
|
|
return
|
|
}
|
|
}
|
|
|
|
tenantID, ok := p.tenantID(db)
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
table := p.tableName(db)
|
|
if table == "" || !p.matchTable(table) {
|
|
return
|
|
}
|
|
|
|
db.Statement.AddClause(clause.Where{
|
|
Exprs: []clause.Expression{
|
|
clause.Eq{
|
|
Column: clause.Column{Name: p.column},
|
|
Value: tenantID,
|
|
},
|
|
},
|
|
})
|
|
}
|
|
|
|
func (p *TenantPlugin) tenantID(db *gorm.DB) (string, bool) {
|
|
if db == nil {
|
|
return "", false
|
|
}
|
|
if value, ok := db.Get(DBSettingTenantID); ok {
|
|
id := strings.TrimSpace(fmt.Sprint(value))
|
|
if id != "" {
|
|
return id, true
|
|
}
|
|
}
|
|
if db.Statement != nil && db.Statement.Context != nil {
|
|
id := TenantIDFromContext(db.Statement.Context)
|
|
if id != "" {
|
|
return id, true
|
|
}
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
func (p *TenantPlugin) tableName(db *gorm.DB) string {
|
|
if db == nil || db.Statement == nil {
|
|
return ""
|
|
}
|
|
if db.Statement.Schema != nil && db.Statement.Schema.Table != "" {
|
|
return normalizeTableName(db.Statement.Schema.Table)
|
|
}
|
|
if db.Statement.Table != "" {
|
|
return normalizeTableName(db.Statement.Table)
|
|
}
|
|
if db.Statement.SQL.String() != "" {
|
|
return normalizeTableName(tableFromSQL(db.Statement.SQL.String()))
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (p *TenantPlugin) matchTable(table string) bool {
|
|
if p.applyAll {
|
|
return true
|
|
}
|
|
_, ok := p.tableSet[normalizeTableName(table)]
|
|
return ok
|
|
}
|
|
|
|
// normalizeTableName reduces a table reference to a bare, lower-case table
// name suitable for lookup in the configured table set.
//
// Only the first whitespace-separated token is considered (so an alias such
// as "users u" is dropped), any qualifier up to the last "." is removed, and
// surrounding backticks or double quotes are stripped. An empty or
// whitespace-only input yields "".
func normalizeTableName(name string) string {
	fields := strings.Fields(name)
	if len(fields) == 0 {
		return ""
	}
	name = fields[0]

	// Drop the qualifier before stripping quotes: for "`db`.`users`" the
	// quotes around the table part are only at the edges once "`db`." is
	// gone. Trimming first would leave a stray leading quote behind.
	if idx := strings.LastIndex(name, "."); idx >= 0 {
		name = name[idx+1:]
	}
	name = strings.Trim(name, "`\"")

	return strings.ToLower(name)
}
|
|
|
|
// tableFromSQL extracts the table reference that follows the first FROM
// keyword in sql.
//
// The statement is split on any whitespace and the keyword is matched
// case-insensitively, so FROM may be preceded or followed by newlines or
// tabs. Punctuation glued to the table token (",", ";" or ")") is cut off
// and surrounding backticks or double quotes are stripped; a schema
// qualifier, if present, is kept. The result is "" when there is no FROM
// keyword, nothing follows it, or it is followed by a parenthesized
// subquery.
func tableFromSQL(sql string) string {
	tokens := strings.Fields(sql)
	for i, tok := range tokens {
		// Compare token by token rather than searching a lower-cased copy:
		// lower-casing can change byte lengths, so offsets found in the
		// copy are not valid for slicing the original string.
		if !strings.EqualFold(tok, "from") {
			continue
		}
		if i+1 >= len(tokens) {
			return ""
		}

		table := tokens[i+1]
		if end := strings.IndexAny(table, ",;)"); end >= 0 {
			table = table[:end]
		}
		table = strings.Trim(table, "`\"")
		if strings.HasPrefix(table, "(") {
			return ""
		}
		return table
	}
	return ""
}
|