feat: 修复部分功能

This commit is contained in:
zwt13703 2026-03-23 16:44:07 +08:00
parent f7cb916e01
commit 009a290135
28 changed files with 593 additions and 63 deletions

View File

@ -40,7 +40,7 @@ CREATE TABLE t_user (
status TINYINT DEFAULT 1 COMMENT '状态0-禁用1-正常',
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
deleted TINYINT DEFAULT 0 COMMENT '软删除0-未删1-已删'
delFlag TINYINT DEFAULT 0 COMMENT '软删除0-未删1-已删'
);
COMMENT ON TABLE t_user IS '用户基础信息表';
@ -56,10 +56,10 @@ CREATE TABLE t_platform_user (
last_login_time TIMESTAMP COMMENT '最后登录时间',
create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
deleted TINYINT DEFAULT 0,
DelFlag TINYINT DEFAULT 0,
-- 联合唯一索引同一平台的openid不能重复
UNIQUE (platform_type, platform_openid),
-- 外键关联用户表
CONSTRAINT fk_platform_user_user_id FOREIGN KEY (user_id) REFERENCES t_user(id) ON DELETE CASCADE
);
COMMENT ON TABLE t_platform_user IS '平台用户关联表(微信/抖音小程序用户信息)';

View File

@ -0,0 +1,42 @@
# 任务执行摘要
## 会话 ID: 20260322-tenant
- [2026-03-22 14:45:56]
- **执行原因**: 增加租户ID多租户查询支持支持从Header读取并按配置过滤。
- **执行过程**:
1. 新增租户配置与中间件,从请求头读取 X-tenantId 并绑定上下文。
2. 引入 GORM 多租户插件与工具函数,支持按表名单自动加 tenant_id 条件、手动覆盖/跳过。
3. 调整现有查询代码与部分 Raw SQL使用带上下文的 DB 并按需拼接租户条件。
- **执行结果**: 已实现可配置的租户过滤与临时覆盖能力,配置已写入配置文件。
# 任务执行摘要
## 会话 ID: 20260322-delflag
- [2026-03-22 15:00:37]
- **执行原因**: 解决查询中使用 delFlag 导致列不存在的错误。
- **执行过程**:
1. 排查 user/platform_user mapper 中软删除条件与更新字段。
2. 将查询与更新字段统一改为数据库字段名 del_flag。
- **执行结果**: 查询条件与软删除更新已改为 del_flag避免列不存在错误。
# 任务执行摘要
## 会话 ID: 20260322-tenant-login
- [2026-03-22 15:11:07]
- **执行原因**: 登录接口要求从 Header 获取 tenantId 并校验。
- **执行过程**:
1. 在登录接口中读取配置的 HeaderKey默认 X-tenantId
2. tenantId 为空时直接返回 400。
3. Swagger 注解增加 Header 参数。
- **执行结果**: 登录接口已强制要求 tenantId。
# 任务执行摘要
## 会话 ID: 20260322-tenant-helper
- [2026-03-22 15:13:58]
- **执行原因**: 抽取通用方法获取租户Header与tenantId。
- **执行过程**:
1. 新增 TenantHeaderKey 与 TenantIDFromHeader 工具方法。
2. 中间件与登录接口改为调用通用方法。
- **执行结果**: tenantId 获取逻辑已统一复用。

View File

@ -0,0 +1,20 @@
# 任务执行摘要
## 会话 ID: 20260323-unused-platformUserID
- [2026-03-23 10:45:49]
- **执行原因**: 修复 wechat_service.go 中未使用变量导致的编译失败。
- **执行过程**:
1. 定位 platformUserID 未被使用的定义与赋值。
2. 移除该变量及相关赋值。
- **执行结果**: 编译错误已排除。
# 任务执行摘要
## 会话 ID: 20260323-tenant-error-const
- [2026-03-23 11:10:36]
- **执行原因**: 抽取 tenantId 缺失的错误文案为常量。
- **执行过程**:
1. 在 common/constants.go 新增 ErrTenantIDMissing 常量。
2. 登录接口引用该常量返回错误。
- **执行结果**: 错误文案已集中管理并复用。

View File

@ -25,6 +25,9 @@ func NewBaseMapper[T any]() *BaseMapper[T] {
// GetDB 获取数据库实例(允许子类覆盖)
func (m *BaseMapper[T]) GetDB() *gorm.DB {
if db := CurrentDB(); db != nil {
return db
}
return m.db
}

View File

@ -26,6 +26,12 @@ const (
HeaderTokenPrefix = "Bearer "
)
// 错误信息常量
const (
ErrTenantIDMissing = "未设置租户Id"
ErrPhonePwdMissing = "手机号和密码不能为空"
)
// 业务状态常量
const (
StateActive = "1" // 使用中

35
server/common/db.go Normal file
View File

@ -0,0 +1,35 @@
package common
import (
"server/config"
"gorm.io/gorm"
)
// CurrentDB 返回带请求上下文的 DB 实例
func CurrentDB() *gorm.DB {
if config.DB == nil {
return nil
}
ctx := GetRequestContext()
if ctx == nil {
return config.DB
}
return config.DB.WithContext(ctx)
}
// WithTenantID 临时覆盖租户ID
func WithTenantID(db *gorm.DB, tenantID string) *gorm.DB {
if db == nil {
return db
}
return db.Set(DBSettingTenantID, tenantID)
}
// SkipTenant 临时跳过租户过滤
func SkipTenant(db *gorm.DB) *gorm.DB {
if db == nil {
return db
}
return db.Set(DBSettingTenantSkip, true)
}

View File

@ -0,0 +1,94 @@
package common
import (
"context"
"fmt"
"runtime"
"strconv"
"strings"
"sync"
)
type contextKey string
const (
// ContextTenantIDKey 用于在 Context 中存储租户ID
ContextTenantIDKey contextKey = "tenantId"
// DBSettingTenantID 用于在 GORM DB 上临时覆盖租户ID
DBSettingTenantID = "tenant_id"
// DBSettingTenantSkip 用于临时跳过租户过滤
DBSettingTenantSkip = "tenant_skip"
)
var requestContextStore sync.Map
// BindRequestContext 绑定当前请求上下文到当前协程,返回清理函数
func BindRequestContext(ctx context.Context) func() {
gid := goroutineID()
if gid == 0 {
return func() {}
}
requestContextStore.Store(gid, ctx)
return func() {
requestContextStore.Delete(gid)
}
}
// GetRequestContext 获取当前协程绑定的请求上下文
func GetRequestContext() context.Context {
gid := goroutineID()
if gid == 0 {
return nil
}
if v, ok := requestContextStore.Load(gid); ok {
if ctx, ok := v.(context.Context); ok {
return ctx
}
}
return nil
}
// TenantIDFromContext 从 Context 中获取租户ID
func TenantIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
value := ctx.Value(ContextTenantIDKey)
if value == nil {
return ""
}
return strings.TrimSpace(toString(value))
}
// CurrentTenantID 获取当前协程绑定的租户ID
func CurrentTenantID() string {
return TenantIDFromContext(GetRequestContext())
}
func goroutineID() uint64 {
var buf [64]byte
n := runtime.Stack(buf[:], false)
if n <= 0 {
return 0
}
fields := strings.Fields(string(buf[:n]))
if len(fields) < 2 {
return 0
}
id, err := strconv.ParseUint(fields[1], 10, 64)
if err != nil {
return 0
}
return id
}
func toString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
default:
return fmt.Sprint(value)
}
}

View File

@ -0,0 +1,11 @@
package common
import "strings"
// TenantIDFromHeader 从请求头获取租户ID
func TenantIDFromHeader(getHeader func(string) string) string {
if getHeader == nil {
return ""
}
return strings.TrimSpace(getHeader(TenantHeaderKey()))
}

View File

@ -0,0 +1,184 @@
package common
import (
"fmt"
"strings"
"server/config"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// InitTenantPlugin 初始化租户插件
func InitTenantPlugin(db *gorm.DB) {
if db == nil {
return
}
plugin := NewTenantPlugin(config.AppConfig.Tenant)
if err := db.Use(plugin); err != nil {
LogError("租户插件初始化失败: %v", err)
return
}
Info("租户插件初始化完成 (enable=%v)", plugin.enabled)
}
// TenantPlugin GORM 多租户插件
type TenantPlugin struct {
enabled bool
column string
tableSet map[string]struct{}
applyAll bool
}
// NewTenantPlugin 创建插件实例
func NewTenantPlugin(cfg config.TenantConfig) *TenantPlugin {
column := strings.TrimSpace(cfg.Column)
if column == "" {
column = "tenant_id"
}
tableSet := make(map[string]struct{})
applyAll := false
for _, table := range cfg.Tables {
name := normalizeTableName(table)
if name == "" {
continue
}
if name == "*" {
applyAll = true
continue
}
tableSet[name] = struct{}{}
}
return &TenantPlugin{
enabled: cfg.Enable,
column: column,
tableSet: tableSet,
applyAll: applyAll,
}
}
// Name 插件名
func (p *TenantPlugin) Name() string {
return "tenant_plugin"
}
// Initialize 注册回调
func (p *TenantPlugin) Initialize(db *gorm.DB) error {
db.Callback().Query().Before("gorm:query").Register("tenant:query", p.before)
db.Callback().Row().Before("gorm:row").Register("tenant:row", p.before)
db.Callback().Raw().Before("gorm:raw").Register("tenant:raw", p.before)
db.Callback().Update().Before("gorm:update").Register("tenant:update", p.before)
db.Callback().Delete().Before("gorm:delete").Register("tenant:delete", p.before)
return nil
}
func (p *TenantPlugin) before(db *gorm.DB) {
if !p.enabled || db == nil || db.Statement == nil {
return
}
if skip, ok := db.Get(DBSettingTenantSkip); ok {
if v, ok := skip.(bool); ok && v {
return
}
}
tenantID, ok := p.tenantID(db)
if !ok {
return
}
table := p.tableName(db)
if table == "" || !p.matchTable(table) {
return
}
db.Statement.AddClause(clause.Where{
Exprs: []clause.Expression{
clause.Eq{
Column: clause.Column{Name: p.column},
Value: tenantID,
},
},
})
}
func (p *TenantPlugin) tenantID(db *gorm.DB) (string, bool) {
if db == nil {
return "", false
}
if value, ok := db.Get(DBSettingTenantID); ok {
id := strings.TrimSpace(fmt.Sprint(value))
if id != "" {
return id, true
}
}
if db.Statement != nil && db.Statement.Context != nil {
id := TenantIDFromContext(db.Statement.Context)
if id != "" {
return id, true
}
}
return "", false
}
func (p *TenantPlugin) tableName(db *gorm.DB) string {
if db == nil || db.Statement == nil {
return ""
}
if db.Statement.Schema != nil && db.Statement.Schema.Table != "" {
return normalizeTableName(db.Statement.Schema.Table)
}
if db.Statement.Table != "" {
return normalizeTableName(db.Statement.Table)
}
if db.Statement.SQL.String() != "" {
return normalizeTableName(tableFromSQL(db.Statement.SQL.String()))
}
return ""
}
func (p *TenantPlugin) matchTable(table string) bool {
if p.applyAll {
return true
}
_, ok := p.tableSet[normalizeTableName(table)]
return ok
}
func normalizeTableName(name string) string {
name = strings.TrimSpace(name)
if name == "" {
return ""
}
fields := strings.Fields(name)
if len(fields) > 0 {
name = fields[0]
}
name = strings.Trim(name, "`\"")
if idx := strings.LastIndex(name, "."); idx >= 0 {
name = name[idx+1:]
}
return strings.ToLower(strings.TrimSpace(name))
}
func tableFromSQL(sql string) string {
lower := strings.ToLower(sql)
idx := strings.Index(lower, " from ")
if idx == -1 {
return ""
}
rest := strings.TrimSpace(sql[idx+6:])
if rest == "" {
return ""
}
fields := strings.Fields(rest)
if len(fields) == 0 {
return ""
}
table := strings.Trim(fields[0], "`\"")
if strings.HasPrefix(table, "(") {
return ""
}
return table
}

View File

@ -0,0 +1,41 @@
package common
import (
"strings"
"server/config"
)
// TenantHeaderKey 获取租户Header字段名
func TenantHeaderKey() string {
headerKey := strings.TrimSpace(config.AppConfig.Tenant.HeaderKey)
if headerKey == "" {
return "X-Tenant-Id"
}
return headerKey
}
// IsTenantTable 判断表是否启用租户过滤
func IsTenantTable(table string) bool {
cfg := config.AppConfig.Tenant
if !cfg.Enable {
return false
}
name := normalizeTableName(table)
if name == "" {
return false
}
for _, item := range cfg.Tables {
tableName := normalizeTableName(item)
if tableName == "" {
continue
}
if tableName == "*" {
return true
}
if strings.EqualFold(tableName, name) {
return true
}
}
return false
}

View File

@ -13,6 +13,14 @@ security:
header_key: X-App-Sign
secret_key: yts@2025#secure
tenant:
enable: true
header_key: X-Tenant-Id
column: tenant_id
tables:
- t_user
- t_platform_user
payload_crypto:
enable: true
header_key: X-App-Encrypt

View File

@ -22,6 +22,7 @@ type appConfig struct {
Redis RedisConfig `yaml:"redis"`
Wechat WechatConfig `yaml:"wechat"`
AppConfig AppVersionConfig `yaml:"app_config"`
Tenant TenantConfig `yaml:"tenant"`
}
// LogConfig 日志配置
@ -124,6 +125,14 @@ type AppVersionConfig struct {
TenantId string `yaml:"tenantId"`
}
// TenantConfig 多租户配置
type TenantConfig struct {
Enable bool `yaml:"enable"` // 是否启用
HeaderKey string `yaml:"header_key"` // Header 中的租户ID字段名
Column string `yaml:"column"` // 表中的租户字段名
Tables []string `yaml:"tables"` // 启用租户过滤的表
}
// AppClientConfig 客户端版本配置
type AppClientConfig struct {
MinVersion string `yaml:"min_version"`

View File

@ -13,17 +13,25 @@ security:
header_key: X-App-Sign
secret_key: yts@2025#secure
tenant:
enable: true
header_key: X-Tenant-Id
column: tenant_id
tables:
- t_user
- t_platform_user
payload_crypto:
enable: false
enable: true
header_key: X-App-Encrypt
secret_key: ""
secret_key: "1"
whitelist:
- /swagger/
request:
enable: false
required: false
response:
enable: false
enable: true
required: false
rate_limit:

View File

@ -13,6 +13,14 @@ security:
header_key: X-App-Sign
secret_key: yts@2025#secure
tenant:
enable: false
header_key: X-tenantId
column: tenant_id
tables:
- yx_user_score
- yx_volunteer
payload_crypto:
enable: false
header_key: X-App-Encrypt

View File

@ -60,6 +60,7 @@ func main() {
// 初始化数据库
config.InitDB()
common.Info("数据库初始化完成")
common.InitTenantPlugin(config.DB)
// 初始化Redis
config.InitRedis()
@ -86,6 +87,7 @@ func main() {
api := r.Group("/api")
// 中间件顺序: 参数加解密 -> 安全校验 -> 限流 -> 登录鉴权
api.Use(middleware.TenantMiddleware())
api.Use(middleware.PayloadCryptoMiddleware())
api.Use(middleware.SecurityMiddleware())
api.Use(middleware.RateLimitMiddleware())

View File

@ -0,0 +1,29 @@
package middleware
import (
"context"
"server/common"
"github.com/gin-gonic/gin"
)
// TenantMiddleware 解析租户ID并绑定到请求上下文
func TenantMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
tenantID := common.TenantIDFromHeader(c.GetHeader)
ctx := c.Request.Context()
if tenantID != "" {
ctx = context.WithValue(ctx, common.ContextTenantIDKey, tenantID)
}
cleanup := common.BindRequestContext(ctx)
defer cleanup()
if ctx != c.Request.Context() {
c.Request = c.Request.WithContext(ctx)
}
c.Next()
}
}

View File

@ -27,13 +27,20 @@ func (ctrl *OpenAuthController) RegisterRoutes(r *gin.RouterGroup) {
// @Tags 对外接口
// @Accept json
// @Produce json
// @Param X-tenantId header string true "租户ID"
// @Param request body dto.UserPasswordLoginRequest true "登录信息"
// @Success 200 {object} common.Response
// @Router /open/user/login [post]
func (ctrl *OpenAuthController) LoginByPhone(c *gin.Context) {
tenantID := common.TenantIDFromHeader(c.GetHeader)
if tenantID == "" {
common.Error(c, 400, common.ErrTenantIDMissing)
return
}
var req apiDto.UserPasswordLoginRequest
if err := c.ShouldBindJSON(&req); err != nil {
common.Error(c, 400, "手机号和密码不能为空")
common.Error(c, 400, common.ErrPhonePwdMissing)
return
}
loginUser, token, err := ctrl.userService.LoginByPhonePassword(req.Phone, req.Password)

View File

@ -32,6 +32,11 @@ func (ctrl *WechatMiniProgramController) RegisterRoutes(r *gin.RouterGroup) {
// @Router /open/wechat/mini/login [post]
func (ctrl *WechatMiniProgramController) MiniLogin(c *gin.Context) {
var req apiDto.WechatMiniLoginRequest
tenantID := common.TenantIDFromHeader(c.GetHeader)
if tenantID == "" {
common.Error(c, 400, common.ErrTenantIDMissing)
return
}
if err := c.ShouldBindJSON(&req); err != nil {
common.Error(c, 400, "参数错误")
return

View File

@ -20,6 +20,7 @@ type AppClientConfig struct {
// AppEndpointConfig 接口或 WebView 配置
type AppEndpointConfig struct {
TenantId string `json:"tenantId"`
BaseURL string `json:"baseUrl"`
Version string `json:"version"`
MinClientVersion string `json:"minClientVersion"`

View File

@ -19,11 +19,11 @@ type WechatMiniLoginRequest struct {
// WechatMiniLoginResponse 微信小程序登录响应
type WechatMiniLoginResponse struct {
UserID int64 `json:"userId"`
PlatformUserID int64 `json:"platformUserId"`
OpenID string `json:"openid"`
UnionID string `json:"unionid"`
SessionKey string `json:"sessionKey"`
Phone string `json:"phone"`
// PlatformUserID int64 `json:"platformUserId"`
// OpenID string `json:"openid"`
// UnionID string `json:"unionid"`
// SessionKey string `json:"sessionKey"`
// Phone string `json:"phone"`
Token string `json:"token"`
IsNewPlatform bool `json:"isNewPlatform"`
IsNewUser bool `json:"isNewUser"`

View File

@ -22,6 +22,7 @@ func (s *AppConfigService) GetConfig() apiDto.AppConfigResponse {
ForceUpdate: cfg.App.ForceUpdate,
},
API: apiDto.AppEndpointConfig{
TenantId: cfg.TenantId,
BaseURL: cfg.API.BaseURL,
Version: cfg.API.Version,
MinClientVersion: cfg.API.MinClientVersion,

View File

@ -62,7 +62,6 @@ func (s *WechatMiniProgramService) Login(req *apiDto.WechatMiniLoginRequest) (*a
isNewPlatform := false
isNewUser := false
var userID int64
var platformUserID int64
var phone string
if req.PhoneCode != "" {
@ -132,13 +131,11 @@ func (s *WechatMiniProgramService) Login(req *apiDto.WechatMiniLoginRequest) (*a
}
userID = user.ID
platformUserID = platform.ID
} else {
return nil, err
}
} else {
userID = platformUser.UserID
platformUserID = platformUser.ID
fields := map[string]interface{}{
"platform_session_key": session.SessionKey,
"last_login_time": now,
@ -179,11 +176,11 @@ func (s *WechatMiniProgramService) Login(req *apiDto.WechatMiniLoginRequest) (*a
return &apiDto.WechatMiniLoginResponse{
UserID: userID,
PlatformUserID: platformUserID,
OpenID: session.OpenID,
UnionID: session.UnionID,
SessionKey: session.SessionKey,
Phone: phone,
// PlatformUserID: platformUserID,
// OpenID: session.OpenID,
// UnionID: session.UnionID,
// SessionKey: session.SessionKey,
// Phone: phone,
Token: token,
IsNewPlatform: isNewPlatform,
IsNewUser: isNewUser,

View File

@ -2,27 +2,26 @@
package mapper
import (
"server/config"
"server/common"
"server/modules/user/entity"
"gorm.io/gorm"
)
type PlatformUserMapper struct {
db *gorm.DB
}
func NewPlatformUserMapper() *PlatformUserMapper {
return &PlatformUserMapper{db: config.DB}
return &PlatformUserMapper{}
}
func (m *PlatformUserMapper) baseDB() *gorm.DB {
return m.db
return common.CurrentDB()
}
// GetDB 获取数据库实例,默认过滤软删除
func (m *PlatformUserMapper) GetDB() *gorm.DB {
return m.baseDB().Where("delFlag = 0")
return m.baseDB().Where("del_flag = 0")
}
// FindAll 分页查询
@ -73,5 +72,5 @@ func (m *PlatformUserMapper) UpdateFields(id int64, fields map[string]interface{
// Delete 逻辑删除
func (m *PlatformUserMapper) Delete(id int64) error {
return m.baseDB().Model(&entity.PlatformUser{}).Where("id = ?", id).Update("delFlag", 1).Error
return m.baseDB().Model(&entity.PlatformUser{}).Where("id = ?", id).Update("del_flag", 1).Error
}

View File

@ -2,27 +2,26 @@
package mapper
import (
"server/config"
"server/common"
"server/modules/user/entity"
"gorm.io/gorm"
)
type UserMapper struct {
db *gorm.DB
}
func NewUserMapper() *UserMapper {
return &UserMapper{db: config.DB}
return &UserMapper{}
}
func (m *UserMapper) baseDB() *gorm.DB {
return m.db
return common.CurrentDB()
}
// GetDB 获取数据库实例,默认过滤软删除
func (m *UserMapper) GetDB() *gorm.DB {
return m.baseDB().Where("delFlag = 0")
return m.baseDB().Where("del_flag = 0")
}
// FindAll 分页查询
@ -66,5 +65,5 @@ func (m *UserMapper) UpdateFields(id int64, fields map[string]interface{}) error
// Delete 逻辑删除
func (m *UserMapper) Delete(id int64) error {
return m.baseDB().Model(&entity.User{}).Where("id = ?", id).Update("delFlag", 1).Error
return m.baseDB().Model(&entity.User{}).Where("id = ?", id).Update("del_flag", 1).Error
}

View File

@ -33,7 +33,7 @@ type UserScoreService struct {
func (s *UserScoreService) GetActiveScoreID(userID string) (string, error) {
var score entity.YxUserScore
// 明确指定字段,提高可读性
err := config.DB.Model(&entity.YxUserScore{}).
err := common.CurrentDB().Model(&entity.YxUserScore{}).
Where("create_by = ? AND state = ?", userID, "1").
Select("id").
First(&score).Error
@ -50,7 +50,7 @@ func (s *UserScoreService) GetActiveScoreID(userID string) (string, error) {
func (s *UserScoreService) GetActiveScoreByID(userID string) (entity.YxUserScore, error) {
var score entity.YxUserScore
// 明确指定字段,提高可读性
err := config.DB.Model(&entity.YxUserScore{}).
err := common.CurrentDB().Model(&entity.YxUserScore{}).
Where("create_by = ? AND state = ?", userID, "1").
First(&score).Error
if err != nil {
@ -68,7 +68,7 @@ func (s *UserScoreService) GetActiveScoreByUserID(userID string) (vo.UserScoreVO
scoreRedisData, err := config.RDB.Get(context.Background(), common.RedisUserScorePrefix+userID).Result()
if err != nil {
// 明确指定字段,提高可读性
err := config.DB.Model(&entity.YxUserScore{}).
err := common.CurrentDB().Model(&entity.YxUserScore{}).
Where("create_by = ? AND state = ?", userID, "1").
First(&score).Error
if err != nil {
@ -99,7 +99,7 @@ func (s *UserScoreService) GetActiveScoreByUserID(userID string) (vo.UserScoreVO
func (s *UserScoreService) GetByID(id string) (vo.UserScoreVO, error) {
var score entity.YxUserScore
err := config.DB.Model(&entity.YxUserScore{}).
err := common.CurrentDB().Model(&entity.YxUserScore{}).
Where("id = ?", id).
First(&score).Error
if err != nil {
@ -116,7 +116,7 @@ func (s *UserScoreService) ListByUser(userID string, page, size int) ([]vo.UserS
var scores []entity.YxUserScore
var total int64
query := config.DB.Model(&entity.YxUserScore{}).Where("create_by = ?", userID)
query := common.CurrentDB().Model(&entity.YxUserScore{}).Where("create_by = ?", userID)
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("查询总数失败: %w", err)
@ -360,7 +360,7 @@ func (s *UserScoreService) SaveUserScore(req *yxDto.SaveScoreRequest) (vo.UserSc
entityItem.UpdateTime = time.Now()
// 3. 执行保存操作(可以包含事务)
tx := config.DB.Begin()
tx := common.CurrentDB().Begin()
defer func() {
if r := recover(); r != nil {
fmt.Printf("【PANIC】事务执行过程中发生panic: %v", r)

View File

@ -4,7 +4,6 @@ package mapper
import (
"fmt"
"server/common"
"server/config"
"server/modules/yx/dto"
"server/modules/yx/entity"
"strings"
@ -85,6 +84,10 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
baseSQL += " AND cm.major_name like ?"
params = append(params, query.Keyword)
}
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) {
baseSQL += " AND cm.tenant_id = ?"
params = append(params, tenantID)
}
// 3. 优化后的总数量COUNT SQL
countSQL := fmt.Sprintf(`
@ -178,6 +181,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
// 整体开始时间
totalStartTime := time.Now()
db := common.CurrentDB()
wg.Add(3)
@ -186,7 +190,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
defer wg.Done()
// 记录该协程单独的开始时间
start := time.Now()
countErr = config.DB.Raw(countSQL, params...).Count(&total).Error
countErr = db.Raw(countSQL, params...).Count(&total).Error
// 计算该协程耗时,通过互斥锁安全写入共享变量
mu.Lock()
queryCost.CountCost = time.Since(start)
@ -198,7 +202,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
defer wg.Done()
// 记录该协程单独的开始时间
start := time.Now()
probCountErr = config.DB.Raw(probCountSQL, params...).Scan(&probCount).Error
probCountErr = db.Raw(probCountSQL, params...).Scan(&probCount).Error
// 计算该协程耗时,通过互斥锁安全写入共享变量
mu.Lock()
queryCost.ProbCountCost = time.Since(start)
@ -210,7 +214,7 @@ func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery)
defer wg.Done()
// 记录该协程单独的开始时间
start := time.Now()
queryErr = config.DB.Raw(mainSQL, params...).Scan(&items).Error
queryErr = db.Raw(mainSQL, params...).Scan(&items).Error
// 计算该协程耗时,通过互斥锁安全写入共享变量
mu.Lock()
queryCost.QueryCost = time.Since(start)
@ -336,6 +340,11 @@ func (m *YxCalculationMajorMapper) FindRecommendList1(query dto.SchoolMajorQuery
sql += " AND cm.main_subjects = ?"
params = append(params, query.MainSubjects)
}
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) {
countSQL += " AND cm.tenant_id = ?"
sql += " AND cm.tenant_id = ?"
params = append(params, tenantID)
}
// 录取概率
switch query.Probability {
@ -355,19 +364,20 @@ func (m *YxCalculationMajorMapper) FindRecommendList1(query dto.SchoolMajorQuery
var wg sync.WaitGroup
var countErr, queryErr error
db := common.CurrentDB()
wg.Add(2)
// 协程1COUNT 查询
go func() {
defer wg.Done()
countErr = config.DB.Raw(countSQL, params...).Count(&total).Error
countErr = db.Raw(countSQL, params...).Count(&total).Error
}()
// 协程2主查询
go func() {
defer wg.Done()
sql += fmt.Sprintf(" LIMIT %d OFFSET %d", query.Size, (query.Page-1)*query.Size)
queryErr = config.DB.Raw(sql, params...).Scan(&items).Error
queryErr = db.Raw(sql, params...).Scan(&items).Error
}()
wg.Wait()
if countErr != nil || queryErr != nil {
@ -405,11 +415,14 @@ func (m *YxCalculationMajorMapper) FindListByCompositeKeys(tableName string, key
db = db.Table(tableName)
}
sql := "SELECT * FROM " + tableName + " WHERE score_id = ? AND (school_code, major_code, enrollment_code) IN ("
sql := "SELECT * FROM " + tableName + " WHERE score_id = ?"
var params []interface{}
// 将 score_id 作为第一个参数
params = append(params, scoreId)
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) {
sql += " AND tenant_id = ?"
params = append(params, tenantID)
}
sql += " AND (school_code, major_code, enrollment_code) IN ("
for i, key := range keys {
parts := strings.Split(key, "_")
@ -478,11 +491,16 @@ func (m *YxCalculationMajorMapper) FindDtoListByCompositeKeys(tableName string,
LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code
LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id
LEFT JOIN yx_school s ON s.id = sc.school_id
WHERE cm.score_id = ? AND (cm.school_code, cm.major_code, cm.enrollment_code) IN (
WHERE cm.score_id = ?
`, tableName)
var params []interface{}
params = append(params, scoreId)
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable(tableName) {
sqlStr += " AND cm.tenant_id = ?"
params = append(params, tenantID)
}
sqlStr += " AND (cm.school_code, cm.major_code, cm.enrollment_code) IN ("
// Build IN clause
var tuples []string

View File

@ -1,7 +1,7 @@
package mapper
import (
"server/config"
"server/common"
"server/modules/yx/entity"
)
@ -25,7 +25,7 @@ func (m *YxHistoryScoreControlLineMapper) SelectByYearAndCategory(year, professi
categories = []string{professionalCategory}
}
err := config.DB.Model(&entity.YxHistoryScoreControlLine{}).
err := common.CurrentDB().Model(&entity.YxHistoryScoreControlLine{}).
Where("year = ?", year).
Where("professional_category IN ?", categories).
Where("category = ?", category).

View File

@ -3,7 +3,6 @@ package service
import (
"server/common"
"server/config"
"server/modules/yx/dto"
"server/modules/yx/entity"
"server/modules/yx/mapper"
@ -109,8 +108,12 @@ func (s *YxHistoryMajorEnrollService) ListBySchoolCodesAndMajorNames(
sql += " AND hme.year IN ?"
params = append(params, years)
}
if tenantID := common.CurrentTenantID(); tenantID != "" && common.IsTenantTable("yx_history_major_enroll") {
sql += " AND hme.tenant_id = ?"
params = append(params, tenantID)
}
sql += " ORDER BY hme.year DESC"
err := config.DB.Raw(sql, params...).Scan(&items).Error
err := common.CurrentDB().Raw(sql, params...).Scan(&items).Error
return items, err
}