diff --git a/docs/Task1.md b/docs/Task1.md index a729f40..00ed570 100644 --- a/docs/Task1.md +++ b/docs/Task1.md @@ -40,7 +40,7 @@ CREATE TABLE t_user ( status TINYINT DEFAULT 1 COMMENT '状态:0-禁用,1-正常', create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', - deleted TINYINT DEFAULT 0 COMMENT '软删除:0-未删,1-已删' + delFlag TINYINT DEFAULT 0 COMMENT '软删除:0-未删,1-已删' ); COMMENT ON TABLE t_user IS '用户基础信息表'; @@ -56,10 +56,10 @@ CREATE TABLE t_platform_user ( last_login_time TIMESTAMP COMMENT '最后登录时间', create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - deleted TINYINT DEFAULT 0, + DelFlag TINYINT DEFAULT 0, -- 联合唯一索引:同一平台的openid不能重复 UNIQUE (platform_type, platform_openid), - -- 外键关联用户表 + CONSTRAINT fk_platform_user_user_id FOREIGN KEY (user_id) REFERENCES t_user(id) ON DELETE CASCADE ); COMMENT ON TABLE t_platform_user IS '平台用户关联表(微信/抖音小程序用户信息)'; diff --git a/docs/tasks/task_detail_2026_03_22.md b/docs/tasks/task_detail_2026_03_22.md new file mode 100644 index 0000000..7beddf9 --- /dev/null +++ b/docs/tasks/task_detail_2026_03_22.md @@ -0,0 +1,42 @@ +# 任务执行摘要 + +## 会话 ID: 20260322-tenant +- [2026-03-22 14:45:56] +- **执行原因**: 增加租户ID多租户查询支持,支持从Header读取并按配置过滤。 +- **执行过程**: + 1. 新增租户配置与中间件,从请求头读取 X-tenantId 并绑定上下文。 + 2. 引入 GORM 多租户插件与工具函数,支持按表名单自动加 tenant_id 条件、手动覆盖/跳过。 + 3. 调整现有查询代码与部分 Raw SQL,使用带上下文的 DB 并按需拼接租户条件。 +- **执行结果**: 已实现可配置的租户过滤与临时覆盖能力,配置已写入配置文件。 + +# 任务执行摘要 + +## 会话 ID: 20260322-delflag +- [2026-03-22 15:00:37] +- **执行原因**: 解决查询中使用 delFlag 导致列不存在的错误。 +- **执行过程**: + 1. 排查 user/platform_user mapper 中软删除条件与更新字段。 + 2. 将查询与更新字段统一改为数据库字段名 del_flag。 +- **执行结果**: 查询条件与软删除更新已改为 del_flag,避免列不存在错误。 + +# 任务执行摘要 + +## 会话 ID: 20260322-tenant-login +- [2026-03-22 15:11:07] +- **执行原因**: 登录接口要求从 Header 获取 tenantId 并校验。 +- **执行过程**: + 1. 在登录接口中读取配置的 HeaderKey(默认 X-tenantId)。 + 2. tenantId 为空时直接返回 400。 + 3. Swagger 注解增加 Header 参数。 +- **执行结果**: 登录接口已强制要求 tenantId。 + +# 任务执行摘要 + +## 会话 ID: 20260322-tenant-helper +- [2026-03-22 15:13:58] +- **执行原因**: 抽取通用方法获取租户Header与tenantId。 +- **执行过程**: + 1. 新增 TenantHeaderKey 与 TenantIDFromHeader 工具方法。 + 2. 中间件与登录接口改为调用通用方法。 +- **执行结果**: tenantId 获取逻辑已统一复用。 + diff --git a/docs/tasks/task_detail_2026_03_23.md b/docs/tasks/task_detail_2026_03_23.md new file mode 100644 index 0000000..6bff954 --- /dev/null +++ b/docs/tasks/task_detail_2026_03_23.md @@ -0,0 +1,20 @@ +# 任务执行摘要 + +## 会话 ID: 20260323-unused-platformUserID +- [2026-03-23 10:45:49] +- **执行原因**: 修复 wechat_service.go 中未使用变量导致的编译失败。 +- **执行过程**: + 1. 定位 platformUserID 未被使用的定义与赋值。 + 2. 移除该变量及相关赋值。 +- **执行结果**: 编译错误已排除。 + +# 任务执行摘要 + +## 会话 ID: 20260323-tenant-error-const +- [2026-03-23 11:10:36] +- **执行原因**: 抽取 tenantId 缺失的错误文案为常量。 +- **执行过程**: + 1. 在 common/constants.go 新增 ErrTenantIDMissing 常量。 + 2. 登录接口引用该常量返回错误。 +- **执行结果**: 错误文案已集中管理并复用。 + diff --git a/server/common/base_mapper.go b/server/common/base_mapper.go index 2709b34..5e58adf 100644 --- a/server/common/base_mapper.go +++ b/server/common/base_mapper.go @@ -25,6 +25,9 @@ func NewBaseMapper[T any]() *BaseMapper[T] { // GetDB 获取数据库实例(允许子类覆盖) func (m *BaseMapper[T]) GetDB() *gorm.DB { + if db := CurrentDB(); db != nil { + return db + } return m.db } @@ -89,4 +92,4 @@ func (m *BaseMapper[T]) BatchUpsert(items []T, updateColumns []string) error { // BatchDelete 批量删除 func (m *BaseMapper[T]) BatchDelete(ids []string) error { return m.GetDB().Delete(new(T), "id IN ?", ids).Error -} \ No newline at end of file +} diff --git a/server/common/constants.go b/server/common/constants.go index 4ac3c6c..b9803dd 100644 --- a/server/common/constants.go +++ b/server/common/constants.go @@ -26,6 +26,12 @@ const ( HeaderTokenPrefix = "Bearer " ) +// 错误信息常量 +const ( + ErrTenantIDMissing = "未设置租户Id" + ErrPhonePwdMissing = "手机号和密码不能为空" +) + // 业务状态常量 const ( StateActive = "1" // 使用中 diff --git a/server/common/db.go b/server/common/db.go new file mode 100644 index 0000000..6f3e183 --- /dev/null +++ b/server/common/db.go @@ -0,0 +1,35 @@ +package common + +import ( + "server/config" + + "gorm.io/gorm" +) + +// CurrentDB 返回带请求上下文的 DB 实例 +func CurrentDB() *gorm.DB { + if config.DB == nil { + return nil + } + ctx := GetRequestContext() + if ctx == nil { + return config.DB + } + return config.DB.WithContext(ctx) +} + +// WithTenantID 临时覆盖租户ID +func WithTenantID(db *gorm.DB, tenantID string) *gorm.DB { + if db == nil { + return db + } + return db.Set(DBSettingTenantID, tenantID) +} + +// SkipTenant 临时跳过租户过滤 +func SkipTenant(db *gorm.DB) *gorm.DB { + if db == nil { + return db + } + return db.Set(DBSettingTenantSkip, true) +} diff --git a/server/common/tenant_context.go b/server/common/tenant_context.go new file mode 100644 index 0000000..4c0cef5 --- /dev/null +++ b/server/common/tenant_context.go @@ -0,0 +1,94 @@ +package common + +import ( + "context" + "fmt" + "runtime" + "strconv" + "strings" + "sync" +) + +type contextKey string + +const ( + // ContextTenantIDKey 用于在 Context 中存储租户ID + ContextTenantIDKey contextKey = "tenantId" + // DBSettingTenantID 用于在 GORM DB 上临时覆盖租户ID + DBSettingTenantID = "tenant_id" + // DBSettingTenantSkip 用于临时跳过租户过滤 + DBSettingTenantSkip = "tenant_skip" +) + +var requestContextStore sync.Map + +// BindRequestContext 绑定当前请求上下文到当前协程,返回清理函数 +func BindRequestContext(ctx context.Context) func() { + gid := goroutineID() + if gid == 0 { + return func() {} + } + requestContextStore.Store(gid, ctx) + return func() { + requestContextStore.Delete(gid) + } +} + +// GetRequestContext 获取当前协程绑定的请求上下文 +func GetRequestContext() context.Context { + gid := goroutineID() + if gid == 0 { + return nil + } + if v, ok := requestContextStore.Load(gid); ok { + if ctx, ok := v.(context.Context); ok { + return ctx + } + } + return nil +} + +// TenantIDFromContext 从 Context 中获取租户ID +func TenantIDFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + value := ctx.Value(ContextTenantIDKey) + if value == nil { + return "" + } + return strings.TrimSpace(toString(value)) +} + +// CurrentTenantID 获取当前协程绑定的租户ID +func CurrentTenantID() string { + return TenantIDFromContext(GetRequestContext()) +} + +func goroutineID() uint64 { + var buf [64]byte + n := runtime.Stack(buf[:], false) + if n <= 0 { + return 0 + } + fields := strings.Fields(string(buf[:n])) + if len(fields) < 2 { + return 0 + } + id, err := strconv.ParseUint(fields[1], 10, 64) + if err != nil { + return 0 + } + return id +} + +func toString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case []byte: + return string(v) + default: + return fmt.Sprint(value) + } +} diff --git a/server/common/tenant_header.go b/server/common/tenant_header.go new file mode 100644 index 0000000..bbc72ac --- /dev/null +++ b/server/common/tenant_header.go @@ -0,0 +1,11 @@ +package common + +import "strings" + +// TenantIDFromHeader 从请求头获取租户ID +func TenantIDFromHeader(getHeader func(string) string) string { + if getHeader == nil { + return "" + } + return strings.TrimSpace(getHeader(TenantHeaderKey())) +} diff --git a/server/common/tenant_plugin.go b/server/common/tenant_plugin.go new file mode 100644 index 0000000..43fcc20 --- /dev/null +++ b/server/common/tenant_plugin.go @@ -0,0 +1,184 @@ +package common + +import ( + "fmt" + "strings" + + "server/config" + + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// InitTenantPlugin 初始化租户插件 +func InitTenantPlugin(db *gorm.DB) { + if db == nil { + return + } + plugin := NewTenantPlugin(config.AppConfig.Tenant) + if err := db.Use(plugin); err != nil { + LogError("租户插件初始化失败: %v", err) + return + } + Info("租户插件初始化完成 (enable=%v)", plugin.enabled) +} + +// TenantPlugin GORM 多租户插件 +type TenantPlugin struct { + enabled bool + column string + tableSet map[string]struct{} + applyAll bool +} + +// NewTenantPlugin 创建插件实例 +func NewTenantPlugin(cfg config.TenantConfig) *TenantPlugin { + column := strings.TrimSpace(cfg.Column) + if column == "" { + column = "tenant_id" + } + tableSet := make(map[string]struct{}) + applyAll := false + for _, table := range cfg.Tables { + name := normalizeTableName(table) + if name == "" { + continue + } + if name == "*" { + applyAll = true + continue + } + tableSet[name] = struct{}{} + } + return &TenantPlugin{ + enabled: cfg.Enable, + column: column, + tableSet: tableSet, + applyAll: applyAll, + } +} + +// Name 插件名 +func (p *TenantPlugin) Name() string { + return "tenant_plugin" +} + +// Initialize 注册回调 +func (p *TenantPlugin) Initialize(db *gorm.DB) error { + db.Callback().Query().Before("gorm:query").Register("tenant:query", p.before) + db.Callback().Row().Before("gorm:row").Register("tenant:row", p.before) + db.Callback().Raw().Before("gorm:raw").Register("tenant:raw", p.before) + db.Callback().Update().Before("gorm:update").Register("tenant:update", p.before) + db.Callback().Delete().Before("gorm:delete").Register("tenant:delete", p.before) + return nil +} + +func (p *TenantPlugin) before(db *gorm.DB) { + if !p.enabled || db == nil || db.Statement == nil { + return + } + if skip, ok := db.Get(DBSettingTenantSkip); ok { + if v, ok := skip.(bool); ok && v { + return + } + } + + tenantID, ok := p.tenantID(db) + if !ok { + return + } + + table := p.tableName(db) + if table == "" || !p.matchTable(table) { + return + } + + db.Statement.AddClause(clause.Where{ + Exprs: []clause.Expression{ + clause.Eq{ + Column: clause.Column{Name: p.column}, + Value: tenantID, + }, + }, + }) +} + +func (p *TenantPlugin) tenantID(db *gorm.DB) (string, bool) { + if db == nil { + return "", false + } + if value, ok := db.Get(DBSettingTenantID); ok { + id := strings.TrimSpace(fmt.Sprint(value)) + if id != "" { + return id, true + } + } + if db.Statement != nil && db.Statement.Context != nil { + id := TenantIDFromContext(db.Statement.Context) + if id != "" { + return id, true + } + } + return "", false +} + +func (p *TenantPlugin) tableName(db *gorm.DB) string { + if db == nil || db.Statement == nil { + return "" + } + if db.Statement.Schema != nil && db.Statement.Schema.Table != "" { + return normalizeTableName(db.Statement.Schema.Table) + } + if db.Statement.Table != "" { + return normalizeTableName(db.Statement.Table) + } + if db.Statement.SQL.String() != "" { + return normalizeTableName(tableFromSQL(db.Statement.SQL.String())) + } + return "" +} + +func (p *TenantPlugin) matchTable(table string) bool { + if p.applyAll { + return true + } + _, ok := p.tableSet[normalizeTableName(table)] + return ok +} + +func normalizeTableName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + fields := strings.Fields(name) + if len(fields) > 0 { + name = fields[0] + } + name = strings.Trim(name, "`\"") + if idx := strings.LastIndex(name, "."); idx >= 0 { + name = name[idx+1:] + } + return strings.ToLower(strings.TrimSpace(name)) +} + +func tableFromSQL(sql string) string { + lower := strings.ToLower(sql) + idx := strings.Index(lower, " from ") + if idx == -1 { + return "" + } + rest := strings.TrimSpace(sql[idx+6:]) + if rest == "" { + return "" + } + fields := strings.Fields(rest) + if len(fields) == 0 { + return "" + } + table := strings.Trim(fields[0], "`\"") + if strings.HasPrefix(table, "(") { + return "" + } + return table +} diff --git a/server/common/tenant_utils.go b/server/common/tenant_utils.go new file mode 100644 index 0000000..217c453 --- /dev/null +++ b/server/common/tenant_utils.go @@ -0,0 +1,41 @@ +package common + +import ( + "strings" + + "server/config" +) + +// TenantHeaderKey 获取租户Header字段名 +func TenantHeaderKey() string { + headerKey := strings.TrimSpace(config.AppConfig.Tenant.HeaderKey) + if headerKey == "" { + return "X-Tenant-Id" + } + return headerKey +} + +// IsTenantTable 判断表是否启用租户过滤 +func IsTenantTable(table string) bool { + cfg := config.AppConfig.Tenant + if !cfg.Enable { + return false + } + name := normalizeTableName(table) + if name == "" { + return false + } + for _, item := range cfg.Tables { + tableName := normalizeTableName(item) + if tableName == "" { + continue + } + if tableName == "*" { + return true + } + if strings.EqualFold(tableName, name) { + return true + } + } + return false +} diff --git a/server/config/config.dev.yaml b/server/config/config.dev.yaml index c1c5cc2..14b5037 100644 --- a/server/config/config.dev.yaml +++ b/server/config/config.dev.yaml @@ -13,6 +13,14 @@ security: header_key: X-App-Sign secret_key: yts@2025#secure +tenant: + enable: true + header_key: X-Tenant-Id + column: tenant_id + tables: + - t_user + - t_platform_user + payload_crypto: enable: true header_key: X-App-Encrypt diff --git a/server/config/config.go b/server/config/config.go index 8ca1675..3325894 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -22,6 +22,7 @@ type appConfig struct { Redis RedisConfig `yaml:"redis"` Wechat WechatConfig `yaml:"wechat"` AppConfig AppVersionConfig `yaml:"app_config"` + Tenant TenantConfig `yaml:"tenant"` } // LogConfig 日志配置 @@ -124,6 +125,14 @@ type AppVersionConfig struct { TenantId string `yaml:"tenantId"` } +// TenantConfig 多租户配置 +type TenantConfig struct { + Enable bool `yaml:"enable"` // 是否启用 + HeaderKey string `yaml:"header_key"` // Header 中的租户ID字段名 + Column string `yaml:"column"` // 表中的租户字段名 + Tables []string `yaml:"tables"` // 启用租户过滤的表 +} + // AppClientConfig 客户端版本配置 type AppClientConfig struct { MinVersion string `yaml:"min_version"` diff --git a/server/config/config.prod.yaml b/server/config/config.prod.yaml index dfa9a4b..e5572da 100644 --- a/server/config/config.prod.yaml +++ b/server/config/config.prod.yaml @@ -1,7 +1,7 @@ server: port: 8081 - worker_id: 1 # 工作机器ID (0-31),多实例部署需配置不同值 - datacenter_id: 0 # 数据中心ID (0-31),多机房部署需配置不同值 # 雪花算法机器ID (0-1023),分布式环境下不同实例需设置不同值,多实例部署时需手动配置 + worker_id: 1 # 工作机器ID (0-31),多实例部署需配置不同值 + datacenter_id: 0 # 数据中心ID (0-31),多机房部署需配置不同值 # 雪花算法机器ID (0-1023),分布式环境下不同实例需设置不同值,多实例部署时需手动配置 log: level: info @@ -13,17 +13,25 @@ security: header_key: X-App-Sign secret_key: yts@2025#secure +tenant: + enable: true + header_key: X-Tenant-Id + column: tenant_id + tables: + - t_user + - t_platform_user + payload_crypto: - enable: false + enable: true header_key: X-App-Encrypt - secret_key: "" + secret_key: "1" whitelist: - /swagger/ request: enable: false required: false response: - enable: false + enable: true required: false rate_limit: diff --git a/server/config/config.test.yaml b/server/config/config.test.yaml index 1b28efc..b98d3ac 100644 --- a/server/config/config.test.yaml +++ b/server/config/config.test.yaml @@ -13,6 +13,14 @@ security: header_key: X-App-Sign secret_key: yts@2025#secure +tenant: + enable: false + header_key: X-tenantId + column: tenant_id + tables: + - yx_user_score + - yx_volunteer + payload_crypto: enable: false header_key: X-App-Encrypt diff --git a/server/main.go b/server/main.go index fc3ca92..3dfff8b 100644 --- a/server/main.go +++ b/server/main.go @@ -60,6 +60,7 @@ func main() { // 初始化数据库 config.InitDB() common.Info("数据库初始化完成") + common.InitTenantPlugin(config.DB) // 初始化Redis config.InitRedis() @@ -86,6 +87,7 @@ func main() { api := r.Group("/api") // 中间件顺序: 参数加解密 -> 安全校验 -> 限流 -> 登录鉴权 + api.Use(middleware.TenantMiddleware()) api.Use(middleware.PayloadCryptoMiddleware()) api.Use(middleware.SecurityMiddleware()) api.Use(middleware.RateLimitMiddleware()) diff --git a/server/middleware/tenant.go b/server/middleware/tenant.go new file mode 100644 index 0000000..9f6393e --- /dev/null +++ b/server/middleware/tenant.go @@ -0,0 +1,29 @@ +package middleware + +import ( + "context" + + "server/common" + + "github.com/gin-gonic/gin" +) + +// TenantMiddleware 解析租户ID并绑定到请求上下文 +func TenantMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + tenantID := common.TenantIDFromHeader(c.GetHeader) + ctx := c.Request.Context() + if tenantID != "" { + ctx = context.WithValue(ctx, common.ContextTenantIDKey, tenantID) + } + + cleanup := common.BindRequestContext(ctx) + defer cleanup() + + if ctx != c.Request.Context() { + c.Request = c.Request.WithContext(ctx) + } + + c.Next() + } +} diff --git a/server/modules/api/controller/user_auth_controller.go b/server/modules/api/controller/user_auth_controller.go index d495673..49eaac8 100644 --- a/server/modules/api/controller/user_auth_controller.go +++ b/server/modules/api/controller/user_auth_controller.go @@ -27,13 +27,20 @@ func (ctrl *OpenAuthController) RegisterRoutes(r *gin.RouterGroup) { // @Tags 对外接口 // @Accept json // @Produce json +// @Param X-tenantId header string true "租户ID" // @Param request body dto.UserPasswordLoginRequest true "登录信息" // @Success 200 {object} common.Response // @Router /open/user/login [post] func (ctrl *OpenAuthController) LoginByPhone(c *gin.Context) { + tenantID := common.TenantIDFromHeader(c.GetHeader) + if tenantID == "" { + common.Error(c, 400, common.ErrTenantIDMissing) + return + } + var req apiDto.UserPasswordLoginRequest if err := c.ShouldBindJSON(&req); err != nil { - common.Error(c, 400, "手机号和密码不能为空") + common.Error(c, 400, common.ErrPhonePwdMissing) return } loginUser, token, err := ctrl.userService.LoginByPhonePassword(req.Phone, req.Password) diff --git a/server/modules/api/controller/wechat_controller.go b/server/modules/api/controller/wechat_controller.go index db721c5..9f9fb8d 100644 --- a/server/modules/api/controller/wechat_controller.go +++ b/server/modules/api/controller/wechat_controller.go @@ -32,6 +32,11 @@ func (ctrl *WechatMiniProgramController) RegisterRoutes(r *gin.RouterGroup) { // @Router /open/wechat/mini/login [post] func (ctrl *WechatMiniProgramController) MiniLogin(c *gin.Context) { var req apiDto.WechatMiniLoginRequest + tenantID := common.TenantIDFromHeader(c.GetHeader) + if tenantID == "" { + common.Error(c, 400, common.ErrTenantIDMissing) + return + } if err := c.ShouldBindJSON(&req); err != nil { common.Error(c, 400, "参数错误") return diff --git a/server/modules/api/dto/app_config_dto.go b/server/modules/api/dto/app_config_dto.go index 71389e2..7af5597 100644 --- a/server/modules/api/dto/app_config_dto.go +++ b/server/modules/api/dto/app_config_dto.go @@ -20,6 +20,7 @@ type AppClientConfig struct { // AppEndpointConfig 接口或 WebView 配置 type AppEndpointConfig struct { + TenantId string `json:"tenantId"` BaseURL string `json:"baseUrl"` Version string `json:"version"` MinClientVersion string `json:"minClientVersion"` diff --git a/server/modules/api/dto/wechat_dto.go b/server/modules/api/dto/wechat_dto.go index 85f1723..3d37cc0 100644 --- a/server/modules/api/dto/wechat_dto.go +++ b/server/modules/api/dto/wechat_dto.go @@ -18,13 +18,13 @@ type WechatMiniLoginRequest struct { // WechatMiniLoginResponse 微信小程序登录响应 type WechatMiniLoginResponse struct { - UserID int64 `json:"userId"` - PlatformUserID int64 `json:"platformUserId"` - OpenID string `json:"openid"` - UnionID string `json:"unionid"` - SessionKey string `json:"sessionKey"` - Phone string `json:"phone"` - Token string `json:"token"` - IsNewPlatform bool `json:"isNewPlatform"` - IsNewUser bool `json:"isNewUser"` + UserID int64 `json:"userId"` + // PlatformUserID int64 `json:"platformUserId"` + // OpenID string `json:"openid"` + // UnionID string `json:"unionid"` + // SessionKey string `json:"sessionKey"` + // Phone string `json:"phone"` + Token string `json:"token"` + IsNewPlatform bool `json:"isNewPlatform"` + IsNewUser bool `json:"isNewUser"` } diff --git a/server/modules/api/service/app_config_service.go b/server/modules/api/service/app_config_service.go index 36668be..100c24f 100644 --- a/server/modules/api/service/app_config_service.go +++ b/server/modules/api/service/app_config_service.go @@ -22,6 +22,7 @@ func (s *AppConfigService) GetConfig() apiDto.AppConfigResponse { ForceUpdate: cfg.App.ForceUpdate, }, API: apiDto.AppEndpointConfig{ + TenantId: cfg.TenantId, BaseURL: cfg.API.BaseURL, Version: cfg.API.Version, MinClientVersion: cfg.API.MinClientVersion, diff --git a/server/modules/api/service/wechat_service.go b/server/modules/api/service/wechat_service.go index 730e361..d39e51b 100644 --- a/server/modules/api/service/wechat_service.go +++ b/server/modules/api/service/wechat_service.go @@ -62,7 +62,6 @@ func (s *WechatMiniProgramService) Login(req *apiDto.WechatMiniLoginRequest) (*a isNewPlatform := false isNewUser := false var userID int64 - var platformUserID int64 var phone string if req.PhoneCode != "" { @@ -132,13 +131,11 @@ func (s *WechatMiniProgramService) Login(req *apiDto.WechatMiniLoginRequest) (*a } userID = user.ID - platformUserID = platform.ID } else { return nil, err } } else { userID = platformUser.UserID - platformUserID = platformUser.ID fields := map[string]interface{}{ "platform_session_key": session.SessionKey, "last_login_time": now, @@ -178,15 +175,15 @@ func (s *WechatMiniProgramService) Login(req *apiDto.WechatMiniLoginRequest) (*a } return &apiDto.WechatMiniLoginResponse{ - UserID: userID, - PlatformUserID: platformUserID, - OpenID: session.OpenID, - UnionID: session.UnionID, - SessionKey: session.SessionKey, - Phone: phone, - Token: token, - IsNewPlatform: isNewPlatform, - IsNewUser: isNewUser, + UserID: userID, + // PlatformUserID: platformUserID, + // OpenID: session.OpenID, + // UnionID: session.UnionID, + // SessionKey: session.SessionKey, + // Phone: phone, + Token: token, + IsNewPlatform: isNewPlatform, + IsNewUser: isNewUser, }, nil } diff --git a/server/modules/user/mapper/platform_user_mapper.go b/server/modules/user/mapper/platform_user_mapper.go index f23c95e..c267b23 100644 --- a/server/modules/user/mapper/platform_user_mapper.go +++ b/server/modules/user/mapper/platform_user_mapper.go @@ -2,27 +2,26 @@ package mapper import ( - "server/config" + "server/common" "server/modules/user/entity" "gorm.io/gorm" ) type PlatformUserMapper struct { - db *gorm.DB } func NewPlatformUserMapper() *PlatformUserMapper { - return &PlatformUserMapper{db: config.DB} + return &PlatformUserMapper{} } func (m *PlatformUserMapper) baseDB() *gorm.DB { - return m.db + return common.CurrentDB() } // GetDB 获取数据库实例,默认过滤软删除 func (m *PlatformUserMapper) GetDB() *gorm.DB { - return m.baseDB().Where("delFlag = 0") + return m.baseDB().Where("del_flag = 0") } // FindAll 分页查询 @@ -73,5 +72,5 @@ func (m *PlatformUserMapper) UpdateFields(id int64, fields map[string]interface{ // Delete 逻辑删除 func (m *PlatformUserMapper) Delete(id int64) error { - return m.baseDB().Model(&entity.PlatformUser{}).Where("id = ?", id).Update("delFlag", 1).Error + return m.baseDB().Model(&entity.PlatformUser{}).Where("id = ?", id).Update("del_flag", 1).Error } diff --git a/server/modules/user/mapper/user_mapper.go b/server/modules/user/mapper/user_mapper.go index c5cfccc..d279af9 100644 --- a/server/modules/user/mapper/user_mapper.go +++ b/server/modules/user/mapper/user_mapper.go @@ -2,27 +2,26 @@ package mapper import ( - "server/config" + "server/common" "server/modules/user/entity" "gorm.io/gorm" ) type UserMapper struct { - db *gorm.DB } func NewUserMapper() *UserMapper { - return &UserMapper{db: config.DB} + return &UserMapper{} } func (m *UserMapper) baseDB() *gorm.DB { - return m.db + return common.CurrentDB() } // GetDB 获取数据库实例,默认过滤软删除 func (m *UserMapper) GetDB() *gorm.DB { - return m.baseDB().Where("delFlag = 0") + return m.baseDB().Where("del_flag = 0") } // FindAll 分页查询 @@ -66,5 +65,5 @@ func (m *UserMapper) UpdateFields(id int64, fields map[string]interface{}) error // Delete 逻辑删除 func (m *UserMapper) Delete(id int64) error { - return m.baseDB().Model(&entity.User{}).Where("id = ?", id).Update("delFlag", 1).Error + return m.baseDB().Model(&entity.User{}).Where("id = ?", id).Update("del_flag", 1).Error } diff --git a/server/modules/user/service/user_score_service.go b/server/modules/user/service/user_score_service.go index b4a3dc8..0f05220 100644 --- a/server/modules/user/service/user_score_service.go +++ b/server/modules/user/service/user_score_service.go @@ -33,7 +33,7 @@ type UserScoreService struct { func (s *UserScoreService) GetActiveScoreID(userID string) (string, error) { var score entity.YxUserScore // 明确指定字段,提高可读性 - err := config.DB.Model(&entity.YxUserScore{}). + err := common.CurrentDB().Model(&entity.YxUserScore{}). Where("create_by = ? AND state = ?", userID, "1"). Select("id"). First(&score).Error @@ -50,7 +50,7 @@ func (s *UserScoreService) GetActiveScoreID(userID string) (string, error) { func (s *UserScoreService) GetActiveScoreByID(userID string) (entity.YxUserScore, error) { var score entity.YxUserScore // 明确指定字段,提高可读性 - err := config.DB.Model(&entity.YxUserScore{}). + err := common.CurrentDB().Model(&entity.YxUserScore{}). Where("create_by = ? AND state = ?", userID, "1"). First(&score).Error if err != nil { @@ -68,7 +68,7 @@ func (s *UserScoreService) GetActiveScoreByUserID(userID string) (vo.UserScoreVO scoreRedisData, err := config.RDB.Get(context.Background(), common.RedisUserScorePrefix+userID).Result() if err != nil { // 明确指定字段,提高可读性 - err := config.DB.Model(&entity.YxUserScore{}). + err := common.CurrentDB().Model(&entity.YxUserScore{}). Where("create_by = ? AND state = ?", userID, "1"). First(&score).Error if err != nil { @@ -99,7 +99,7 @@ func (s *UserScoreService) GetActiveScoreByUserID(userID string) (vo.UserScoreVO func (s *UserScoreService) GetByID(id string) (vo.UserScoreVO, error) { var score entity.YxUserScore - err := config.DB.Model(&entity.YxUserScore{}). + err := common.CurrentDB().Model(&entity.YxUserScore{}). Where("id = ?", id). First(&score).Error if err != nil { @@ -116,7 +116,7 @@ func (s *UserScoreService) ListByUser(userID string, page, size int) ([]vo.UserS var scores []entity.YxUserScore var total int64 - query := config.DB.Model(&entity.YxUserScore{}).Where("create_by = ?", userID) + query := common.CurrentDB().Model(&entity.YxUserScore{}).Where("create_by = ?", userID) if err := query.Count(&total).Error; err != nil { return nil, 0, fmt.Errorf("查询总数失败: %w", err) @@ -360,7 +360,7 @@ func (s *UserScoreService) SaveUserScore(req *yxDto.SaveScoreRequest) (vo.UserSc entityItem.UpdateTime = time.Now() // 3. 执行保存操作(可以包含事务) - tx := config.DB.Begin() + tx := common.CurrentDB().Begin() defer func() { if r := recover(); r != nil { fmt.Printf("【PANIC】事务执行过程中发生panic: %v", r) diff --git a/server/modules/yx/mapper/yx_calculation_major_mapper.go b/server/modules/yx/mapper/yx_calculation_major_mapper.go index a6d8f66..410d091 100644 --- a/server/modules/yx/mapper/yx_calculation_major_mapper.go +++ b/server/modules/yx/mapper/yx_calculation_major_mapper.go @@ -4,7 +4,6 @@ package mapper import ( "fmt" "server/common" - "server/config" "server/modules/yx/dto" "server/modules/yx/entity" "strings" @@ -85,6 +84,10 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery) baseSQL += " AND cm.major_name like ?" params = append(params, query.Keyword) } + if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) { + baseSQL += " AND cm.tenant_id = ?" + params = append(params, tenantID) + } // 3. 优化后的总数量COUNT SQL countSQL := fmt.Sprintf(` @@ -178,6 +181,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery) // 整体开始时间 totalStartTime := time.Now() + db := common.CurrentDB() wg.Add(3) @@ -186,7 +190,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery) defer wg.Done() // 记录该协程单独的开始时间 start := time.Now() - countErr = config.DB.Raw(countSQL, params...).Count(&total).Error + countErr = db.Raw(countSQL, params...).Count(&total).Error // 计算该协程耗时,通过互斥锁安全写入共享变量 mu.Lock() queryCost.CountCost = time.Since(start) @@ -198,7 +202,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery) defer wg.Done() // 记录该协程单独的开始时间 start := time.Now() - probCountErr = config.DB.Raw(probCountSQL, params...).Scan(&probCount).Error + probCountErr = db.Raw(probCountSQL, params...).Scan(&probCount).Error // 计算该协程耗时,通过互斥锁安全写入共享变量 mu.Lock() queryCost.ProbCountCost = time.Since(start) @@ -210,7 +214,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery) defer wg.Done() // 记录该协程单独的开始时间 start := time.Now() - queryErr = config.DB.Raw(mainSQL, params...).Scan(&items).Error + queryErr = db.Raw(mainSQL, params...).Scan(&items).Error // 计算该协程耗时,通过互斥锁安全写入共享变量 mu.Lock() queryCost.QueryCost = time.Since(start) @@ -336,6 +340,11 @@ func (m *YxCalculationMajorMapper) FindRecommendList1(query dto.SchoolMajorQuery sql += " AND cm.main_subjects = ?" params = append(params, query.MainSubjects) } + if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) { + countSQL += " AND cm.tenant_id = ?" + sql += " AND cm.tenant_id = ?" + params = append(params, tenantID) + } // 录取概率 switch query.Probability { @@ -355,19 +364,20 @@ func (m *YxCalculationMajorMapper) FindRecommendList1(query dto.SchoolMajorQuery var wg sync.WaitGroup var countErr, queryErr error + db := common.CurrentDB() wg.Add(2) // 协程1:COUNT 查询 go func() { defer wg.Done() - countErr = config.DB.Raw(countSQL, params...).Count(&total).Error + countErr = db.Raw(countSQL, params...).Count(&total).Error }() // 协程2:主查询 go func() { defer wg.Done() sql += fmt.Sprintf(" LIMIT %d OFFSET %d", query.Size, (query.Page-1)*query.Size) - queryErr = config.DB.Raw(sql, params...).Scan(&items).Error + queryErr = db.Raw(sql, params...).Scan(&items).Error }() wg.Wait() if countErr != nil || queryErr != nil { @@ -405,11 +415,14 @@ func (m *YxCalculationMajorMapper) FindListByCompositeKeys(tableName string, key db = db.Table(tableName) } - sql := "SELECT * FROM " + tableName + " WHERE score_id = ? AND (school_code, major_code, enrollment_code) IN (" + sql := "SELECT * FROM " + tableName + " WHERE score_id = ?" var params []interface{} - - // 将 score_id 作为第一个参数 params = append(params, scoreId) + if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) { + sql += " AND tenant_id = ?" + params = append(params, tenantID) + } + sql += " AND (school_code, major_code, enrollment_code) IN (" for i, key := range keys { parts := strings.Split(key, "_") @@ -478,11 +491,16 @@ func (m *YxCalculationMajorMapper) FindDtoListByCompositeKeys(tableName string, LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id LEFT JOIN yx_school s ON s.id = sc.school_id - WHERE cm.score_id = ? AND (cm.school_code, cm.major_code, cm.enrollment_code) IN ( + WHERE cm.score_id = ? `, tableName) var params []interface{} params = append(params, scoreId) + if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) { + sqlStr += " AND cm.tenant_id = ?" + params = append(params, tenantID) + } + sqlStr += " AND (cm.school_code, cm.major_code, cm.enrollment_code) IN (" // Build IN clause var tuples []string diff --git a/server/modules/yx/mapper/yx_history_score_control_line_mapper.go b/server/modules/yx/mapper/yx_history_score_control_line_mapper.go index cb317d7..814413c 100644 --- a/server/modules/yx/mapper/yx_history_score_control_line_mapper.go +++ b/server/modules/yx/mapper/yx_history_score_control_line_mapper.go @@ -1,7 +1,7 @@ package mapper import ( - "server/config" + "server/common" "server/modules/yx/entity" ) @@ -25,7 +25,7 @@ func (m *YxHistoryScoreControlLineMapper) SelectByYearAndCategory(year, professi categories = []string{professionalCategory} } - err := config.DB.Model(&entity.YxHistoryScoreControlLine{}). + err := common.CurrentDB().Model(&entity.YxHistoryScoreControlLine{}). Where("year = ?", year). Where("professional_category IN ?", categories). Where("category = ?", category). diff --git a/server/modules/yx/service/yx_history_major_enroll_service.go b/server/modules/yx/service/yx_history_major_enroll_service.go index 8861fac..5d0fbf4 100644 --- a/server/modules/yx/service/yx_history_major_enroll_service.go +++ b/server/modules/yx/service/yx_history_major_enroll_service.go @@ -3,7 +3,6 @@ package service import ( "server/common" - "server/config" "server/modules/yx/dto" "server/modules/yx/entity" "server/modules/yx/mapper" @@ -109,8 +108,12 @@ func (s *YxHistoryMajorEnrollService) ListBySchoolCodesAndMajorNames( sql += " AND hme.year IN ?" params = append(params, years) } + if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable("yx_history_major_enroll") { + sql += " AND hme.tenant_id = ?" + params = append(params, tenantID) + } sql += " ORDER BY hme.year DESC" - err := config.DB.Raw(sql, params...).Scan(&items).Error + err := common.CurrentDB().Raw(sql, params...).Scan(&items).Error return items, err }