feat: 修复部分功能

This commit is contained in:
zwt13703 2026-03-23 16:44:07 +08:00
parent f7cb916e01
commit 009a290135
28 changed files with 593 additions and 63 deletions

View File

@ -40,7 +40,7 @@ CREATE TABLE t_user (
status TINYINT DEFAULT 1 COMMENT '状态0-禁用1-正常', status TINYINT DEFAULT 1 COMMENT '状态0-禁用1-正常',
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
deleted TINYINT DEFAULT 0 COMMENT '软删除0-未删1-已删' delFlag TINYINT DEFAULT 0 COMMENT '软删除0-未删1-已删'
); );
COMMENT ON TABLE t_user IS '用户基础信息表'; COMMENT ON TABLE t_user IS '用户基础信息表';
@ -56,10 +56,10 @@ CREATE TABLE t_platform_user (
last_login_time TIMESTAMP COMMENT '最后登录时间', last_login_time TIMESTAMP COMMENT '最后登录时间',
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
deleted TINYINT DEFAULT 0, DelFlag TINYINT DEFAULT 0,
-- 联合唯一索引同一平台的openid不能重复 -- 联合唯一索引同一平台的openid不能重复
UNIQUE (platform_type, platform_openid), UNIQUE (platform_type, platform_openid),
-- 外键关联用户表
CONSTRAINT fk_platform_user_user_id FOREIGN KEY (user_id) REFERENCES t_user(id) ON DELETE CASCADE CONSTRAINT fk_platform_user_user_id FOREIGN KEY (user_id) REFERENCES t_user(id) ON DELETE CASCADE
); );
COMMENT ON TABLE t_platform_user IS '平台用户关联表(微信/抖音小程序用户信息)'; COMMENT ON TABLE t_platform_user IS '平台用户关联表(微信/抖音小程序用户信息)';

View File

@ -0,0 +1,42 @@
# 任务执行摘要
## 会话 ID: 20260322-tenant
- [2026-03-22 14:45:56]
- **执行原因**: 增加租户ID多租户查询支持支持从Header读取并按配置过滤。
- **执行过程**:
1. 新增租户配置与中间件,从请求头读取 X-tenantId 并绑定上下文。
2. 引入 GORM 多租户插件与工具函数,支持按表名单自动加 tenant_id 条件、手动覆盖/跳过。
3. 调整现有查询代码与部分 Raw SQL使用带上下文的 DB 并按需拼接租户条件。
- **执行结果**: 已实现可配置的租户过滤与临时覆盖能力,配置已写入配置文件。
# 任务执行摘要
## 会话 ID: 20260322-delflag
- [2026-03-22 15:00:37]
- **执行原因**: 解决查询中使用 delFlag 导致列不存在的错误。
- **执行过程**:
1. 排查 user/platform_user mapper 中软删除条件与更新字段。
2. 将查询与更新字段统一改为数据库字段名 del_flag。
- **执行结果**: 查询条件与软删除更新已改为 del_flag避免列不存在错误。
# 任务执行摘要
## 会话 ID: 20260322-tenant-login
- [2026-03-22 15:11:07]
- **执行原因**: 登录接口要求从 Header 获取 tenantId 并校验。
- **执行过程**:
1. 在登录接口中读取配置的 HeaderKey默认 X-tenantId
2. tenantId 为空时直接返回 400。
3. Swagger 注解增加 Header 参数。
- **执行结果**: 登录接口已强制要求 tenantId。
# 任务执行摘要
## 会话 ID: 20260322-tenant-helper
- [2026-03-22 15:13:58]
- **执行原因**: 抽取通用方法获取租户Header与tenantId。
- **执行过程**:
1. 新增 TenantHeaderKey 与 TenantIDFromHeader 工具方法。
2. 中间件与登录接口改为调用通用方法。
- **执行结果**: tenantId 获取逻辑已统一复用。

View File

@ -0,0 +1,20 @@
# 任务执行摘要
## 会话 ID: 20260323-unused-platformUserID
- [2026-03-23 10:45:49]
- **执行原因**: 修复 wechat_service.go 中未使用变量导致的编译失败。
- **执行过程**:
1. 定位 platformUserID 未被使用的定义与赋值。
2. 移除该变量及相关赋值。
- **执行结果**: 编译错误已排除。
# 任务执行摘要
## 会话 ID: 20260323-tenant-error-const
- [2026-03-23 11:10:36]
- **执行原因**: 抽取 tenantId 缺失的错误文案为常量。
- **执行过程**:
1. 在 common/constants.go 新增 ErrTenantIDMissing 常量。
2. 登录接口引用该常量返回错误。
- **执行结果**: 错误文案已集中管理并复用。

View File

@ -25,6 +25,9 @@ func NewBaseMapper[T any]() *BaseMapper[T] {
// GetDB 获取数据库实例(允许子类覆盖) // GetDB 获取数据库实例(允许子类覆盖)
func (m *BaseMapper[T]) GetDB() *gorm.DB { func (m *BaseMapper[T]) GetDB() *gorm.DB {
if db := CurrentDB(); db != nil {
return db
}
return m.db return m.db
} }

View File

@ -26,6 +26,12 @@ const (
HeaderTokenPrefix = "Bearer " HeaderTokenPrefix = "Bearer "
) )
// 错误信息常量
const (
ErrTenantIDMissing = "未设置租户Id"
ErrPhonePwdMissing = "手机号和密码不能为空"
)
// 业务状态常量 // 业务状态常量
const ( const (
StateActive = "1" // 使用中 StateActive = "1" // 使用中

35
server/common/db.go Normal file
View File

@ -0,0 +1,35 @@
package common
import (
"server/config"
"gorm.io/gorm"
)
// CurrentDB 返回带请求上下文的 DB 实例
func CurrentDB() *gorm.DB {
if config.DB == nil {
return nil
}
ctx := GetRequestContext()
if ctx == nil {
return config.DB
}
return config.DB.WithContext(ctx)
}
// WithTenantID 临时覆盖租户ID
func WithTenantID(db *gorm.DB, tenantID string) *gorm.DB {
if db == nil {
return db
}
return db.Set(DBSettingTenantID, tenantID)
}
// SkipTenant 临时跳过租户过滤
func SkipTenant(db *gorm.DB) *gorm.DB {
if db == nil {
return db
}
return db.Set(DBSettingTenantSkip, true)
}

View File

@ -0,0 +1,94 @@
package common
import (
"context"
"fmt"
"runtime"
"strconv"
"strings"
"sync"
)
type contextKey string
const (
// ContextTenantIDKey 用于在 Context 中存储租户ID
ContextTenantIDKey contextKey = "tenantId"
// DBSettingTenantID 用于在 GORM DB 上临时覆盖租户ID
DBSettingTenantID = "tenant_id"
// DBSettingTenantSkip 用于临时跳过租户过滤
DBSettingTenantSkip = "tenant_skip"
)
var requestContextStore sync.Map
// BindRequestContext 绑定当前请求上下文到当前协程,返回清理函数
func BindRequestContext(ctx context.Context) func() {
gid := goroutineID()
if gid == 0 {
return func() {}
}
requestContextStore.Store(gid, ctx)
return func() {
requestContextStore.Delete(gid)
}
}
// GetRequestContext 获取当前协程绑定的请求上下文
func GetRequestContext() context.Context {
gid := goroutineID()
if gid == 0 {
return nil
}
if v, ok := requestContextStore.Load(gid); ok {
if ctx, ok := v.(context.Context); ok {
return ctx
}
}
return nil
}
// TenantIDFromContext 从 Context 中获取租户ID
func TenantIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
value := ctx.Value(ContextTenantIDKey)
if value == nil {
return ""
}
return strings.TrimSpace(toString(value))
}
// CurrentTenantID 获取当前协程绑定的租户ID
func CurrentTenantID() string {
return TenantIDFromContext(GetRequestContext())
}
func goroutineID() uint64 {
var buf [64]byte
n := runtime.Stack(buf[:], false)
if n <= 0 {
return 0
}
fields := strings.Fields(string(buf[:n]))
if len(fields) < 2 {
return 0
}
id, err := strconv.ParseUint(fields[1], 10, 64)
if err != nil {
return 0
}
return id
}
func toString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
default:
return fmt.Sprint(value)
}
}

View File

@ -0,0 +1,11 @@
package common
import "strings"
// TenantIDFromHeader 从请求头获取租户ID
func TenantIDFromHeader(getHeader func(string) string) string {
if getHeader == nil {
return ""
}
return strings.TrimSpace(getHeader(TenantHeaderKey()))
}

View File

@ -0,0 +1,184 @@
package common
import (
"fmt"
"strings"
"server/config"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// InitTenantPlugin 初始化租户插件
func InitTenantPlugin(db *gorm.DB) {
if db == nil {
return
}
plugin := NewTenantPlugin(config.AppConfig.Tenant)
if err := db.Use(plugin); err != nil {
LogError("租户插件初始化失败: %v", err)
return
}
Info("租户插件初始化完成 (enable=%v)", plugin.enabled)
}
// TenantPlugin GORM 多租户插件
type TenantPlugin struct {
enabled bool
column string
tableSet map[string]struct{}
applyAll bool
}
// NewTenantPlugin 创建插件实例
func NewTenantPlugin(cfg config.TenantConfig) *TenantPlugin {
column := strings.TrimSpace(cfg.Column)
if column == "" {
column = "tenant_id"
}
tableSet := make(map[string]struct{})
applyAll := false
for _, table := range cfg.Tables {
name := normalizeTableName(table)
if name == "" {
continue
}
if name == "*" {
applyAll = true
continue
}
tableSet[name] = struct{}{}
}
return &TenantPlugin{
enabled: cfg.Enable,
column: column,
tableSet: tableSet,
applyAll: applyAll,
}
}
// Name 插件名
func (p *TenantPlugin) Name() string {
return "tenant_plugin"
}
// Initialize 注册回调
func (p *TenantPlugin) Initialize(db *gorm.DB) error {
db.Callback().Query().Before("gorm:query").Register("tenant:query", p.before)
db.Callback().Row().Before("gorm:row").Register("tenant:row", p.before)
db.Callback().Raw().Before("gorm:raw").Register("tenant:raw", p.before)
db.Callback().Update().Before("gorm:update").Register("tenant:update", p.before)
db.Callback().Delete().Before("gorm:delete").Register("tenant:delete", p.before)
return nil
}
func (p *TenantPlugin) before(db *gorm.DB) {
if !p.enabled || db == nil || db.Statement == nil {
return
}
if skip, ok := db.Get(DBSettingTenantSkip); ok {
if v, ok := skip.(bool); ok && v {
return
}
}
tenantID, ok := p.tenantID(db)
if !ok {
return
}
table := p.tableName(db)
if table == "" || !p.matchTable(table) {
return
}
db.Statement.AddClause(clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Name: p.column},
Value: tenantID,
},
},
})
}
func (p *TenantPlugin) tenantID(db *gorm.DB) (string, bool) {
if db == nil {
return "", false
}
if value, ok := db.Get(DBSettingTenantID); ok {
id := strings.TrimSpace(fmt.Sprint(value))
if id != "" {
return id, true
}
}
if db.Statement != nil && db.Statement.Context != nil {
id := TenantIDFromContext(db.Statement.Context)
if id != "" {
return id, true
}
}
return "", false
}
func (p *TenantPlugin) tableName(db *gorm.DB) string {
if db == nil || db.Statement == nil {
return ""
}
if db.Statement.Schema != nil && db.Statement.Schema.Table != "" {
return normalizeTableName(db.Statement.Schema.Table)
}
if db.Statement.Table != "" {
return normalizeTableName(db.Statement.Table)
}
if db.Statement.SQL.String() != "" {
return normalizeTableName(tableFromSQL(db.Statement.SQL.String()))
}
return ""
}
func (p *TenantPlugin) matchTable(table string) bool {
if p.applyAll {
return true
}
_, ok := p.tableSet[normalizeTableName(table)]
return ok
}
func normalizeTableName(name string) string {
name = strings.TrimSpace(name)
if name == "" {
return ""
}
fields := strings.Fields(name)
if len(fields) > 0 {
name = fields[0]
}
name = strings.Trim(name, "`\"")
if idx := strings.LastIndex(name, "."); idx >= 0 {
name = name[idx+1:]
}
return strings.ToLower(strings.TrimSpace(name))
}
func tableFromSQL(sql string) string {
lower := strings.ToLower(sql)
idx := strings.Index(lower, " from ")
if idx == -1 {
return ""
}
rest := strings.TrimSpace(sql[idx+6:])
if rest == "" {
return ""
}
fields := strings.Fields(rest)
if len(fields) == 0 {
return ""
}
table := strings.Trim(fields[0], "`\"")
if strings.HasPrefix(table, "(") {
return ""
}
return table
}

View File

@ -0,0 +1,41 @@
package common
import (
"strings"
"server/config"
)
// TenantHeaderKey 获取租户Header字段名
func TenantHeaderKey() string {
headerKey := strings.TrimSpace(config.AppConfig.Tenant.HeaderKey)
if headerKey == "" {
return "X-Tenant-Id"
}
return headerKey
}
// IsTenantTable 判断表是否启用租户过滤
func IsTenantTable(table string) bool {
cfg := config.AppConfig.Tenant
if !cfg.Enable {
return false
}
name := normalizeTableName(table)
if name == "" {
return false
}
for _, item := range cfg.Tables {
tableName := normalizeTableName(item)
if tableName == "" {
continue
}
if tableName == "*" {
return true
}
if strings.EqualFold(tableName, name) {
return true
}
}
return false
}

View File

@ -13,6 +13,14 @@ security:
header_key: X-App-Sign header_key: X-App-Sign
secret_key: yts@2025#secure secret_key: yts@2025#secure
tenant:
enable: true
header_key: X-Tenant-Id
column: tenant_id
tables:
- t_user
- t_platform_user
payload_crypto: payload_crypto:
enable: true enable: true
header_key: X-App-Encrypt header_key: X-App-Encrypt

View File

@ -22,6 +22,7 @@ type appConfig struct {
Redis RedisConfig `yaml:"redis"` Redis RedisConfig `yaml:"redis"`
Wechat WechatConfig `yaml:"wechat"` Wechat WechatConfig `yaml:"wechat"`
AppConfig AppVersionConfig `yaml:"app_config"` AppConfig AppVersionConfig `yaml:"app_config"`
Tenant TenantConfig `yaml:"tenant"`
} }
// LogConfig 日志配置 // LogConfig 日志配置
@ -124,6 +125,14 @@ type AppVersionConfig struct {
TenantId string `yaml:"tenantId"` TenantId string `yaml:"tenantId"`
} }
// TenantConfig 多租户配置
type TenantConfig struct {
Enable bool `yaml:"enable"` // 是否启用
HeaderKey string `yaml:"header_key"` // Header 中的租户ID字段名
Column string `yaml:"column"` // 表中的租户字段名
Tables []string `yaml:"tables"` // 启用租户过滤的表
}
// AppClientConfig 客户端版本配置 // AppClientConfig 客户端版本配置
type AppClientConfig struct { type AppClientConfig struct {
MinVersion string `yaml:"min_version"` MinVersion string `yaml:"min_version"`

View File

@ -1,7 +1,7 @@
server: server:
port: 8081 port: 8081
worker_id: 1 # 工作机器ID (0-31),多实例部署需配置不同值 worker_id: 1 # 工作机器ID (0-31),多实例部署需配置不同值
datacenter_id: 0 # 数据中心ID (0-31),多机房部署需配置不同值 # 雪花算法机器ID (0-1023),分布式环境下不同实例需设置不同值,多实例部署时需手动配置 datacenter_id: 0 # 数据中心ID (0-31),多机房部署需配置不同值 # 雪花算法机器ID (0-1023),分布式环境下不同实例需设置不同值,多实例部署时需手动配置
log: log:
level: info level: info
@ -13,17 +13,25 @@ security:
header_key: X-App-Sign header_key: X-App-Sign
secret_key: yts@2025#secure secret_key: yts@2025#secure
tenant:
enable: true
header_key: X-Tenant-Id
column: tenant_id
tables:
- t_user
- t_platform_user
payload_crypto: payload_crypto:
enable: false enable: true
header_key: X-App-Encrypt header_key: X-App-Encrypt
secret_key: "" secret_key: "1"
whitelist: whitelist:
- /swagger/ - /swagger/
request: request:
enable: false enable: false
required: false required: false
response: response:
enable: false enable: true
required: false required: false
rate_limit: rate_limit:

View File

@ -13,6 +13,14 @@ security:
header_key: X-App-Sign header_key: X-App-Sign
secret_key: yts@2025#secure secret_key: yts@2025#secure
tenant:
enable: false
header_key: X-tenantId
column: tenant_id
tables:
- yx_user_score
- yx_volunteer
payload_crypto: payload_crypto:
enable: false enable: false
header_key: X-App-Encrypt header_key: X-App-Encrypt

View File

@ -60,6 +60,7 @@ func main() {
// 初始化数据库 // 初始化数据库
config.InitDB() config.InitDB()
common.Info("数据库初始化完成") common.Info("数据库初始化完成")
common.InitTenantPlugin(config.DB)
// 初始化Redis // 初始化Redis
config.InitRedis() config.InitRedis()
@ -86,6 +87,7 @@ func main() {
api := r.Group("/api") api := r.Group("/api")
// 中间件顺序: 参数加解密 -> 安全校验 -> 限流 -> 登录鉴权 // 中间件顺序: 参数加解密 -> 安全校验 -> 限流 -> 登录鉴权
api.Use(middleware.TenantMiddleware())
api.Use(middleware.PayloadCryptoMiddleware()) api.Use(middleware.PayloadCryptoMiddleware())
api.Use(middleware.SecurityMiddleware()) api.Use(middleware.SecurityMiddleware())
api.Use(middleware.RateLimitMiddleware()) api.Use(middleware.RateLimitMiddleware())

View File

@ -0,0 +1,29 @@
package middleware
import (
"context"
"server/common"
"github.com/gin-gonic/gin"
)
// TenantMiddleware 解析租户ID并绑定到请求上下文
func TenantMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
tenantID := common.TenantIDFromHeader(c.GetHeader)
ctx := c.Request.Context()
if tenantID != "" {
ctx = context.WithValue(ctx, common.ContextTenantIDKey, tenantID)
}
cleanup := common.BindRequestContext(ctx)
defer cleanup()
if ctx != c.Request.Context() {
c.Request = c.Request.WithContext(ctx)
}
c.Next()
}
}

View File

@ -27,13 +27,20 @@ func (ctrl *OpenAuthController) RegisterRoutes(r *gin.RouterGroup) {
// @Tags 对外接口 // @Tags 对外接口
// @Accept json // @Accept json
// @Produce json // @Produce json
// @Param X-tenantId header string true "租户ID"
// @Param request body dto.UserPasswordLoginRequest true "登录信息" // @Param request body dto.UserPasswordLoginRequest true "登录信息"
// @Success 200 {object} common.Response // @Success 200 {object} common.Response
// @Router /open/user/login [post] // @Router /open/user/login [post]
func (ctrl *OpenAuthController) LoginByPhone(c *gin.Context) { func (ctrl *OpenAuthController) LoginByPhone(c *gin.Context) {
tenantID := common.TenantIDFromHeader(c.GetHeader)
if tenantID == "" {
common.Error(c, 400, common.ErrTenantIDMissing)
return
}
var req apiDto.UserPasswordLoginRequest var req apiDto.UserPasswordLoginRequest
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
common.Error(c, 400, "手机号和密码不能为空") common.Error(c, 400, common.ErrPhonePwdMissing)
return return
} }
loginUser, token, err := ctrl.userService.LoginByPhonePassword(req.Phone, req.Password) loginUser, token, err := ctrl.userService.LoginByPhonePassword(req.Phone, req.Password)

View File

@ -32,6 +32,11 @@ func (ctrl *WechatMiniProgramController) RegisterRoutes(r *gin.RouterGroup) {
// @Router /open/wechat/mini/login [post] // @Router /open/wechat/mini/login [post]
func (ctrl *WechatMiniProgramController) MiniLogin(c *gin.Context) { func (ctrl *WechatMiniProgramController) MiniLogin(c *gin.Context) {
var req apiDto.WechatMiniLoginRequest var req apiDto.WechatMiniLoginRequest
tenantID := common.TenantIDFromHeader(c.GetHeader)
if tenantID == "" {
common.Error(c, 400, common.ErrTenantIDMissing)
return
}
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
common.Error(c, 400, "参数错误") common.Error(c, 400, "参数错误")
return return

View File

@ -20,6 +20,7 @@ type AppClientConfig struct {
// AppEndpointConfig 接口或 WebView 配置 // AppEndpointConfig 接口或 WebView 配置
type AppEndpointConfig struct { type AppEndpointConfig struct {
TenantId string `json:"tenantId"`
BaseURL string `json:"baseUrl"` BaseURL string `json:"baseUrl"`
Version string `json:"version"` Version string `json:"version"`
MinClientVersion string `json:"minClientVersion"` MinClientVersion string `json:"minClientVersion"`

View File

@ -18,13 +18,13 @@ type WechatMiniLoginRequest struct {
// WechatMiniLoginResponse 微信小程序登录响应 // WechatMiniLoginResponse 微信小程序登录响应
type WechatMiniLoginResponse struct { type WechatMiniLoginResponse struct {
UserID int64 `json:"userId"` UserID int64 `json:"userId"`
PlatformUserID int64 `json:"platformUserId"` // PlatformUserID int64 `json:"platformUserId"`
OpenID string `json:"openid"` // OpenID string `json:"openid"`
UnionID string `json:"unionid"` // UnionID string `json:"unionid"`
SessionKey string `json:"sessionKey"` // SessionKey string `json:"sessionKey"`
Phone string `json:"phone"` // Phone string `json:"phone"`
Token string `json:"token"` Token string `json:"token"`
IsNewPlatform bool `json:"isNewPlatform"` IsNewPlatform bool `json:"isNewPlatform"`
IsNewUser bool `json:"isNewUser"` IsNewUser bool `json:"isNewUser"`
} }

View File

@ -22,6 +22,7 @@ func (s *AppConfigService) GetConfig() apiDto.AppConfigResponse {
ForceUpdate: cfg.App.ForceUpdate, ForceUpdate: cfg.App.ForceUpdate,
}, },
API: apiDto.AppEndpointConfig{ API: apiDto.AppEndpointConfig{
TenantId: cfg.TenantId,
BaseURL: cfg.API.BaseURL, BaseURL: cfg.API.BaseURL,
Version: cfg.API.Version, Version: cfg.API.Version,
MinClientVersion: cfg.API.MinClientVersion, MinClientVersion: cfg.API.MinClientVersion,

View File

@ -62,7 +62,6 @@ func (s *WechatMiniProgramService) Login(req *apiDto.WechatMiniLoginRequest) (*a
isNewPlatform := false isNewPlatform := false
isNewUser := false isNewUser := false
var userID int64 var userID int64
var platformUserID int64
var phone string var phone string
if req.PhoneCode != "" { if req.PhoneCode != "" {
@ -132,13 +131,11 @@ func (s *WechatMiniProgramService) Login(req *apiDto.WechatMiniLoginRequest) (*a
} }
userID = user.ID userID = user.ID
platformUserID = platform.ID
} else { } else {
return nil, err return nil, err
} }
} else { } else {
userID = platformUser.UserID userID = platformUser.UserID
platformUserID = platformUser.ID
fields := map[string]interface{}{ fields := map[string]interface{}{
"platform_session_key": session.SessionKey, "platform_session_key": session.SessionKey,
"last_login_time": now, "last_login_time": now,
@ -178,15 +175,15 @@ func (s *WechatMiniProgramService) Login(req *apiDto.WechatMiniLoginRequest) (*a
} }
return &apiDto.WechatMiniLoginResponse{ return &apiDto.WechatMiniLoginResponse{
UserID: userID, UserID: userID,
PlatformUserID: platformUserID, // PlatformUserID: platformUserID,
OpenID: session.OpenID, // OpenID: session.OpenID,
UnionID: session.UnionID, // UnionID: session.UnionID,
SessionKey: session.SessionKey, // SessionKey: session.SessionKey,
Phone: phone, // Phone: phone,
Token: token, Token: token,
IsNewPlatform: isNewPlatform, IsNewPlatform: isNewPlatform,
IsNewUser: isNewUser, IsNewUser: isNewUser,
}, nil }, nil
} }

View File

@ -2,27 +2,26 @@
package mapper package mapper
import ( import (
"server/config" "server/common"
"server/modules/user/entity" "server/modules/user/entity"
"gorm.io/gorm" "gorm.io/gorm"
) )
type PlatformUserMapper struct { type PlatformUserMapper struct {
db *gorm.DB
} }
func NewPlatformUserMapper() *PlatformUserMapper { func NewPlatformUserMapper() *PlatformUserMapper {
return &PlatformUserMapper{db: config.DB} return &PlatformUserMapper{}
} }
func (m *PlatformUserMapper) baseDB() *gorm.DB { func (m *PlatformUserMapper) baseDB() *gorm.DB {
return m.db return common.CurrentDB()
} }
// GetDB 获取数据库实例,默认过滤软删除 // GetDB 获取数据库实例,默认过滤软删除
func (m *PlatformUserMapper) GetDB() *gorm.DB { func (m *PlatformUserMapper) GetDB() *gorm.DB {
return m.baseDB().Where("delFlag = 0") return m.baseDB().Where("del_flag = 0")
} }
// FindAll 分页查询 // FindAll 分页查询
@ -73,5 +72,5 @@ func (m *PlatformUserMapper) UpdateFields(id int64, fields map[string]interface{
// Delete 逻辑删除 // Delete 逻辑删除
func (m *PlatformUserMapper) Delete(id int64) error { func (m *PlatformUserMapper) Delete(id int64) error {
return m.baseDB().Model(&entity.PlatformUser{}).Where("id = ?", id).Update("delFlag", 1).Error return m.baseDB().Model(&entity.PlatformUser{}).Where("id = ?", id).Update("del_flag", 1).Error
} }

View File

@ -2,27 +2,26 @@
package mapper package mapper
import ( import (
"server/config" "server/common"
"server/modules/user/entity" "server/modules/user/entity"
"gorm.io/gorm" "gorm.io/gorm"
) )
type UserMapper struct { type UserMapper struct {
db *gorm.DB
} }
func NewUserMapper() *UserMapper { func NewUserMapper() *UserMapper {
return &UserMapper{db: config.DB} return &UserMapper{}
} }
func (m *UserMapper) baseDB() *gorm.DB { func (m *UserMapper) baseDB() *gorm.DB {
return m.db return common.CurrentDB()
} }
// GetDB 获取数据库实例,默认过滤软删除 // GetDB 获取数据库实例,默认过滤软删除
func (m *UserMapper) GetDB() *gorm.DB { func (m *UserMapper) GetDB() *gorm.DB {
return m.baseDB().Where("delFlag = 0") return m.baseDB().Where("del_flag = 0")
} }
// FindAll 分页查询 // FindAll 分页查询
@ -66,5 +65,5 @@ func (m *UserMapper) UpdateFields(id int64, fields map[string]interface{}) error
// Delete 逻辑删除 // Delete 逻辑删除
func (m *UserMapper) Delete(id int64) error { func (m *UserMapper) Delete(id int64) error {
return m.baseDB().Model(&entity.User{}).Where("id = ?", id).Update("delFlag", 1).Error return m.baseDB().Model(&entity.User{}).Where("id = ?", id).Update("del_flag", 1).Error
} }

View File

@ -33,7 +33,7 @@ type UserScoreService struct {
func (s *UserScoreService) GetActiveScoreID(userID string) (string, error) { func (s *UserScoreService) GetActiveScoreID(userID string) (string, error) {
var score entity.YxUserScore var score entity.YxUserScore
// 明确指定字段,提高可读性 // 明确指定字段,提高可读性
err := config.DB.Model(&entity.YxUserScore{}). err := common.CurrentDB().Model(&entity.YxUserScore{}).
Where("create_by = ? AND state = ?", userID, "1"). Where("create_by = ? AND state = ?", userID, "1").
Select("id"). Select("id").
First(&score).Error First(&score).Error
@ -50,7 +50,7 @@ func (s *UserScoreService) GetActiveScoreID(userID string) (string, error) {
func (s *UserScoreService) GetActiveScoreByID(userID string) (entity.YxUserScore, error) { func (s *UserScoreService) GetActiveScoreByID(userID string) (entity.YxUserScore, error) {
var score entity.YxUserScore var score entity.YxUserScore
// 明确指定字段,提高可读性 // 明确指定字段,提高可读性
err := config.DB.Model(&entity.YxUserScore{}). err := common.CurrentDB().Model(&entity.YxUserScore{}).
Where("create_by = ? AND state = ?", userID, "1"). Where("create_by = ? AND state = ?", userID, "1").
First(&score).Error First(&score).Error
if err != nil { if err != nil {
@ -68,7 +68,7 @@ func (s *UserScoreService) GetActiveScoreByUserID(userID string) (vo.UserScoreVO
scoreRedisData, err := config.RDB.Get(context.Background(), common.RedisUserScorePrefix+userID).Result() scoreRedisData, err := config.RDB.Get(context.Background(), common.RedisUserScorePrefix+userID).Result()
if err != nil { if err != nil {
// 明确指定字段,提高可读性 // 明确指定字段,提高可读性
err := config.DB.Model(&entity.YxUserScore{}). err := common.CurrentDB().Model(&entity.YxUserScore{}).
Where("create_by = ? AND state = ?", userID, "1"). Where("create_by = ? AND state = ?", userID, "1").
First(&score).Error First(&score).Error
if err != nil { if err != nil {
@ -99,7 +99,7 @@ func (s *UserScoreService) GetActiveScoreByUserID(userID string) (vo.UserScoreVO
func (s *UserScoreService) GetByID(id string) (vo.UserScoreVO, error) { func (s *UserScoreService) GetByID(id string) (vo.UserScoreVO, error) {
var score entity.YxUserScore var score entity.YxUserScore
err := config.DB.Model(&entity.YxUserScore{}). err := common.CurrentDB().Model(&entity.YxUserScore{}).
Where("id = ?", id). Where("id = ?", id).
First(&score).Error First(&score).Error
if err != nil { if err != nil {
@ -116,7 +116,7 @@ func (s *UserScoreService) ListByUser(userID string, page, size int) ([]vo.UserS
var scores []entity.YxUserScore var scores []entity.YxUserScore
var total int64 var total int64
query := config.DB.Model(&entity.YxUserScore{}).Where("create_by = ?", userID) query := common.CurrentDB().Model(&entity.YxUserScore{}).Where("create_by = ?", userID)
if err := query.Count(&total).Error; err != nil { if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("查询总数失败: %w", err) return nil, 0, fmt.Errorf("查询总数失败: %w", err)
@ -360,7 +360,7 @@ func (s *UserScoreService) SaveUserScore(req *yxDto.SaveScoreRequest) (vo.UserSc
entityItem.UpdateTime = time.Now() entityItem.UpdateTime = time.Now()
// 3. 执行保存操作(可以包含事务) // 3. 执行保存操作(可以包含事务)
tx := config.DB.Begin() tx := common.CurrentDB().Begin()
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
fmt.Printf("【PANIC】事务执行过程中发生panic: %v", r) fmt.Printf("【PANIC】事务执行过程中发生panic: %v", r)

View File

@ -4,7 +4,6 @@ package mapper
import ( import (
"fmt" "fmt"
"server/common" "server/common"
"server/config"
"server/modules/yx/dto" "server/modules/yx/dto"
"server/modules/yx/entity" "server/modules/yx/entity"
"strings" "strings"
@ -85,6 +84,10 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
baseSQL += " AND cm.major_name like ?" baseSQL += " AND cm.major_name like ?"
params = append(params, query.Keyword) params = append(params, query.Keyword)
} }
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) {
baseSQL += " AND cm.tenant_id = ?"
params = append(params, tenantID)
}
// 3. 优化后的总数量COUNT SQL // 3. 优化后的总数量COUNT SQL
countSQL := fmt.Sprintf(` countSQL := fmt.Sprintf(`
@ -178,6 +181,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
// 整体开始时间 // 整体开始时间
totalStartTime := time.Now() totalStartTime := time.Now()
db := common.CurrentDB()
wg.Add(3) wg.Add(3)
@ -186,7 +190,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
defer wg.Done() defer wg.Done()
// 记录该协程单独的开始时间 // 记录该协程单独的开始时间
start := time.Now() start := time.Now()
countErr = config.DB.Raw(countSQL, params...).Count(&total).Error countErr = db.Raw(countSQL, params...).Count(&total).Error
// 计算该协程耗时,通过互斥锁安全写入共享变量 // 计算该协程耗时,通过互斥锁安全写入共享变量
mu.Lock() mu.Lock()
queryCost.CountCost = time.Since(start) queryCost.CountCost = time.Since(start)
@ -198,7 +202,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
defer wg.Done() defer wg.Done()
// 记录该协程单独的开始时间 // 记录该协程单独的开始时间
start := time.Now() start := time.Now()
probCountErr = config.DB.Raw(probCountSQL, params...).Scan(&probCount).Error probCountErr = db.Raw(probCountSQL, params...).Scan(&probCount).Error
// 计算该协程耗时,通过互斥锁安全写入共享变量 // 计算该协程耗时,通过互斥锁安全写入共享变量
mu.Lock() mu.Lock()
queryCost.ProbCountCost = time.Since(start) queryCost.ProbCountCost = time.Since(start)
@ -210,7 +214,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
defer wg.Done() defer wg.Done()
// 记录该协程单独的开始时间 // 记录该协程单独的开始时间
start := time.Now() start := time.Now()
queryErr = config.DB.Raw(mainSQL, params...).Scan(&items).Error queryErr = db.Raw(mainSQL, params...).Scan(&items).Error
// 计算该协程耗时,通过互斥锁安全写入共享变量 // 计算该协程耗时,通过互斥锁安全写入共享变量
mu.Lock() mu.Lock()
queryCost.QueryCost = time.Since(start) queryCost.QueryCost = time.Since(start)
@ -336,6 +340,11 @@ func (m *YxCalculationMajorMapper) FindRecommendList1(query dto.SchoolMajorQuery
sql += " AND cm.main_subjects = ?" sql += " AND cm.main_subjects = ?"
params = append(params, query.MainSubjects) params = append(params, query.MainSubjects)
} }
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) {
countSQL += " AND cm.tenant_id = ?"
sql += " AND cm.tenant_id = ?"
params = append(params, tenantID)
}
// 录取概率 // 录取概率
switch query.Probability { switch query.Probability {
@ -355,19 +364,20 @@ func (m *YxCalculationMajorMapper) FindRecommendList1(query dto.SchoolMajorQuery
var wg sync.WaitGroup var wg sync.WaitGroup
var countErr, queryErr error var countErr, queryErr error
db := common.CurrentDB()
wg.Add(2) wg.Add(2)
// 协程1COUNT 查询 // 协程1COUNT 查询
go func() { go func() {
defer wg.Done() defer wg.Done()
countErr = config.DB.Raw(countSQL, params...).Count(&total).Error countErr = db.Raw(countSQL, params...).Count(&total).Error
}() }()
// 协程2主查询 // 协程2主查询
go func() { go func() {
defer wg.Done() defer wg.Done()
sql += fmt.Sprintf(" LIMIT %d OFFSET %d", query.Size, (query.Page-1)*query.Size) sql += fmt.Sprintf(" LIMIT %d OFFSET %d", query.Size, (query.Page-1)*query.Size)
queryErr = config.DB.Raw(sql, params...).Scan(&items).Error queryErr = db.Raw(sql, params...).Scan(&items).Error
}() }()
wg.Wait() wg.Wait()
if countErr != nil || queryErr != nil { if countErr != nil || queryErr != nil {
@ -405,11 +415,14 @@ func (m *YxCalculationMajorMapper) FindListByCompositeKeys(tableName string, key
db = db.Table(tableName) db = db.Table(tableName)
} }
sql := "SELECT * FROM " + tableName + " WHERE score_id = ? AND (school_code, major_code, enrollment_code) IN (" sql := "SELECT * FROM " + tableName + " WHERE score_id = ?"
var params []interface{} var params []interface{}
// 将 score_id 作为第一个参数
params = append(params, scoreId) params = append(params, scoreId)
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) {
sql += " AND tenant_id = ?"
params = append(params, tenantID)
}
sql += " AND (school_code, major_code, enrollment_code) IN ("
for i, key := range keys { for i, key := range keys {
parts := strings.Split(key, "_") parts := strings.Split(key, "_")
@ -478,11 +491,16 @@ func (m *YxCalculationMajorMapper) FindDtoListByCompositeKeys(tableName string,
LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code
LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id
LEFT JOIN yx_school s ON s.id = sc.school_id LEFT JOIN yx_school s ON s.id = sc.school_id
WHERE cm.score_id = ? AND (cm.school_code, cm.major_code, cm.enrollment_code) IN ( WHERE cm.score_id = ?
`, tableName) `, tableName)
var params []interface{} var params []interface{}
params = append(params, scoreId) params = append(params, scoreId)
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) {
sqlStr += " AND cm.tenant_id = ?"
params = append(params, tenantID)
}
sqlStr += " AND (cm.school_code, cm.major_code, cm.enrollment_code) IN ("
// Build IN clause // Build IN clause
var tuples []string var tuples []string

View File

@ -1,7 +1,7 @@
package mapper package mapper
import ( import (
"server/config" "server/common"
"server/modules/yx/entity" "server/modules/yx/entity"
) )
@ -25,7 +25,7 @@ func (m *YxHistoryScoreControlLineMapper) SelectByYearAndCategory(year, professi
categories = []string{professionalCategory} categories = []string{professionalCategory}
} }
err := config.DB.Model(&entity.YxHistoryScoreControlLine{}). err := common.CurrentDB().Model(&entity.YxHistoryScoreControlLine{}).
Where("year = ?", year). Where("year = ?", year).
Where("professional_category IN ?", categories). Where("professional_category IN ?", categories).
Where("category = ?", category). Where("category = ?", category).

View File

@ -3,7 +3,6 @@ package service
import ( import (
"server/common" "server/common"
"server/config"
"server/modules/yx/dto" "server/modules/yx/dto"
"server/modules/yx/entity" "server/modules/yx/entity"
"server/modules/yx/mapper" "server/modules/yx/mapper"
@ -109,8 +108,12 @@ func (s *YxHistoryMajorEnrollService) ListBySchoolCodesAndMajorNames(
sql += " AND hme.year IN ?" sql += " AND hme.year IN ?"
params = append(params, years) params = append(params, years)
} }
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable("yx_history_major_enroll") {
sql += " AND hme.tenant_id = ?"
params = append(params, tenantID)
}
sql += " ORDER BY hme.year DESC" sql += " ORDER BY hme.year DESC"
err := config.DB.Raw(sql, params...).Scan(&items).Error err := common.CurrentDB().Raw(sql, params...).Scan(&items).Error
return items, err return items, err
} }