424 lines
13 KiB
Go
424 lines
13 KiB
Go
// Package mapper implements the data-access layer.
|
||
package mapper
|
||
|
||
import (
|
||
"fmt"
|
||
"server/config"
|
||
"server/modules/yx/dto"
|
||
"server/modules/yx/entity"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"gorm.io/gorm/clause"
|
||
)
|
||
|
||
// YxCalculationMajorMapper groups the data-access operations for
// YxCalculationMajor rows. It has no fields, so every instance is
// interchangeable.
type YxCalculationMajorMapper struct{}

// NewYxCalculationMajorMapper allocates and returns a new, empty mapper.
func NewYxCalculationMajorMapper() *YxCalculationMajorMapper {
	return new(YxCalculationMajorMapper)
}
|
||
|
||
// 先定义存储各协程耗时的结构体(局部使用,也可全局复用)
|
||
type QueryCostTime struct {
|
||
CountCost time.Duration // 总数量查询耗时
|
||
ProbCountCost time.Duration // 四种概率数量查询耗时
|
||
QueryCost time.Duration // 主列表查询耗时
|
||
TotalCost time.Duration // 整体总耗时
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) FindAll(page, size int) ([]entity.YxCalculationMajor, int64, error) {
|
||
var items []entity.YxCalculationMajor
|
||
var total int64
|
||
config.DB.Model(&entity.YxCalculationMajor{}).Count(&total)
|
||
err := config.DB.Offset((page - 1) * size).Limit(size).Find(&items).Error
|
||
return items, total, err
|
||
}
|
||
|
||
// 调整返回值:新增 ProbabilityCountDTO,返回列表、总数量、四种概率各自数量
|
||
func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery) ([]dto.UserMajorDTO, int64, dto.ProbabilityCountDTO, error) {
|
||
var items []dto.UserMajorDTO
|
||
var total int64
|
||
var probCount dto.ProbabilityCountDTO // 四种概率的数量统计结果
|
||
|
||
// 1. 表名合法性校验:非空 + 白名单
|
||
tableName := query.UserScoreVO.CalculationTableName
|
||
if tableName == "" {
|
||
return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("CalculationTableName is empty")
|
||
}
|
||
// if !validTableNames[] {
|
||
// return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("invalid table name: %s, potential SQL injection risk", tableName)
|
||
// }
|
||
|
||
// 2. 基础条件SQL(共用过滤条件,排除概率筛选)
|
||
baseSQL := " WHERE 1=1 AND cm.state > 0 "
|
||
params := []interface{}{}
|
||
|
||
// 拼接共用过滤条件(与原有列表查询条件一致,保证统计结果准确性)
|
||
if query.UserScoreVO.ID != "" {
|
||
baseSQL += " AND cm.score_id = ?"
|
||
params = append(params, query.UserScoreVO.ID)
|
||
}
|
||
if query.MajorType != "" {
|
||
baseSQL += " AND cm.major_type = ?"
|
||
params = append(params, query.MajorType)
|
||
}
|
||
if query.Category != "" {
|
||
baseSQL += " AND cm.category = ?"
|
||
params = append(params, query.Category)
|
||
}
|
||
if len(query.MajorTypeChildren) > 0 {
|
||
placeholders := strings.Repeat("?,", len(query.MajorTypeChildren)-1) + "?"
|
||
baseSQL += " AND cm.major_type_child IN (" + placeholders + ")"
|
||
for _, v := range query.MajorTypeChildren {
|
||
params = append(params, v)
|
||
}
|
||
}
|
||
if query.MainSubjects != "" {
|
||
baseSQL += " AND cm.main_subjects = ?"
|
||
params = append(params, query.MainSubjects)
|
||
}
|
||
|
||
// 3. 优化后的总数量COUNT SQL
|
||
countSQL := fmt.Sprintf(`
|
||
SELECT COUNT(cm.id) FROM %s cm
|
||
%s
|
||
`, tableName, baseSQL)
|
||
|
||
// 4. 四种概率批量统计SQL(使用CASE WHEN一次查询,性能最优)
|
||
probCountSQL := fmt.Sprintf(`
|
||
SELECT
|
||
SUM(CASE WHEN cm.enroll_probability < 60 THEN 1 ELSE 0 END) AS hard_admit,
|
||
SUM(CASE WHEN cm.enroll_probability >= 60 AND cm.enroll_probability < 73 THEN 1 ELSE 0 END) AS impact,
|
||
SUM(CASE WHEN cm.enroll_probability >= 73 AND cm.enroll_probability < 93 THEN 1 ELSE 0 END) AS stable,
|
||
SUM(CASE WHEN cm.enroll_probability >= 93 THEN 1 ELSE 0 END) AS secure
|
||
FROM %s cm
|
||
%s
|
||
`, tableName, baseSQL)
|
||
|
||
// 5. 主查询SQL(保留原有字段和JOIN)
|
||
mainSQL := fmt.Sprintf(`
|
||
SELECT
|
||
cm.id,
|
||
s.school_name,
|
||
s.school_icon,
|
||
cm.state,
|
||
cm.school_code,
|
||
cm.major_code,
|
||
cm.major_name,
|
||
cm.enrollment_code,
|
||
cm.tuition,
|
||
cm.detail as majorDetail,
|
||
cm.category,
|
||
cm.batch,
|
||
cm.private_student_converted_score as privateStudentScore,
|
||
cm.student_old_converted_score as studentScore,
|
||
cm.student_converted_score,
|
||
cm.enroll_probability,
|
||
cm.rules_enroll_probability_sx,
|
||
cm.rules_enroll_probability,
|
||
cm.probability_operator,
|
||
cm.major_type,
|
||
cm.major_type_child,
|
||
cm.plan_num,
|
||
cm.main_subjects,
|
||
cm.limitation,
|
||
cm.other_score_limitation,
|
||
s.province as province,
|
||
s.school_nature as schoolNature,
|
||
s.institution_type as institutionType
|
||
FROM %s cm
|
||
LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code
|
||
LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id
|
||
LEFT JOIN yx_school s ON s.id = sc.school_id
|
||
%s
|
||
`, tableName, baseSQL)
|
||
|
||
// 拼接传入概率的筛选条件(兼容原有业务逻辑)
|
||
switch query.Probability {
|
||
case "难录取":
|
||
mainSQL += " AND cm.enroll_probability < 60"
|
||
case "可冲击":
|
||
mainSQL += " AND (cm.enroll_probability >= 60 and cm.enroll_probability < 73)"
|
||
case "较稳妥":
|
||
mainSQL += " AND (cm.enroll_probability >= 73 and cm.enroll_probability < 93)"
|
||
case "可保底":
|
||
mainSQL += " AND (cm.enroll_probability >= 93)"
|
||
}
|
||
|
||
// 6. 分页参数合法性校验
|
||
page := query.Page
|
||
size := query.Size
|
||
if page < 1 {
|
||
page = 1
|
||
}
|
||
if size < 1 {
|
||
size = 10
|
||
}
|
||
if size > 100 {
|
||
size = 100
|
||
}
|
||
offset := (page - 1) * size
|
||
// 提前拼接分页条件,避免协程内操作共享变量
|
||
mainSQL += fmt.Sprintf(" LIMIT %d OFFSET %d", size, offset)
|
||
|
||
// 7. 协程并发执行三个查询(总数量、概率数量、主列表),提升性能
|
||
// ---------------------- 核心局部代码(替换你原来的协程块) ----------------------
|
||
var wg sync.WaitGroup
|
||
var countErr, probCountErr, queryErr error
|
||
var queryCost QueryCostTime // 存储各协程耗时
|
||
var mu sync.Mutex // 互斥锁:防止多协程同时修改queryCost引发竞态问题
|
||
|
||
// 整体开始时间
|
||
totalStartTime := time.Now()
|
||
|
||
wg.Add(3)
|
||
|
||
// 协程1:总数量查询(单独记录耗时)
|
||
go func() {
|
||
defer wg.Done()
|
||
// 记录该协程单独的开始时间
|
||
start := time.Now()
|
||
countErr = config.DB.Raw(countSQL, params...).Count(&total).Error
|
||
// 计算该协程耗时,通过互斥锁安全写入共享变量
|
||
mu.Lock()
|
||
queryCost.CountCost = time.Now().Sub(start)
|
||
mu.Unlock()
|
||
}()
|
||
|
||
// 协程2:四种概率数量批量查询(单独记录耗时)
|
||
go func() {
|
||
defer wg.Done()
|
||
// 记录该协程单独的开始时间
|
||
start := time.Now()
|
||
probCountErr = config.DB.Raw(probCountSQL, params...).Scan(&probCount).Error
|
||
// 计算该协程耗时,通过互斥锁安全写入共享变量
|
||
mu.Lock()
|
||
queryCost.ProbCountCost = time.Now().Sub(start)
|
||
mu.Unlock()
|
||
}()
|
||
|
||
// 协程3:主列表查询(单独记录耗时)
|
||
go func() {
|
||
defer wg.Done()
|
||
// 记录该协程单独的开始时间
|
||
start := time.Now()
|
||
queryErr = config.DB.Raw(mainSQL, params...).Scan(&items).Error
|
||
// 计算该协程耗时,通过互斥锁安全写入共享变量
|
||
mu.Lock()
|
||
queryCost.QueryCost = time.Now().Sub(start)
|
||
mu.Unlock()
|
||
}()
|
||
|
||
wg.Wait()
|
||
|
||
// 计算整体总耗时
|
||
queryCost.TotalCost = time.Now().Sub(totalStartTime)
|
||
|
||
// 打印各协程耗时和总耗时(按需输出,可注释或删除)
|
||
fmt.Printf("各查询耗时统计:\n")
|
||
fmt.Printf(" 总数量查询耗时:%v\n", queryCost.CountCost)
|
||
fmt.Printf(" 概率数量查询耗时:%v\n", queryCost.ProbCountCost)
|
||
fmt.Printf(" 主列表查询耗时:%v\n", queryCost.QueryCost)
|
||
fmt.Printf(" 整体总耗时:%v\n", queryCost.TotalCost)
|
||
|
||
// 8. 错误处理
|
||
if countErr != nil {
|
||
return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("failed to query total count: %w", countErr)
|
||
}
|
||
if probCountErr != nil {
|
||
return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("failed to query probability count: %w", probCountErr)
|
||
}
|
||
if queryErr != nil {
|
||
return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("failed to query recommend major list: %w", queryErr)
|
||
}
|
||
|
||
return items, total, probCount, nil
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) FindRecommendList1(query dto.SchoolMajorQuery) ([]dto.UserMajorDTO, int64, error) {
|
||
var items []dto.UserMajorDTO
|
||
var total int64
|
||
|
||
// 确保表名存在,防止 SQL 注入或空表名
|
||
tableName := query.UserScoreVO.CalculationTableName
|
||
if tableName == "" {
|
||
return nil, 0, fmt.Errorf("CalculationTableName is empty")
|
||
}
|
||
|
||
// 使用 Sprintf 动态插入表名
|
||
countSQL := fmt.Sprintf(`
|
||
SELECT COUNT(cm.id) FROM %s cm
|
||
LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code
|
||
LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id
|
||
LEFT JOIN yx_school s ON s.id = sc.school_id
|
||
WHERE 1=1 AND cm.state > 0
|
||
`, tableName)
|
||
|
||
sql := fmt.Sprintf(`
|
||
SELECT
|
||
cm.id,
|
||
s.school_name,
|
||
s.school_icon,
|
||
cm.state,
|
||
cm.school_code,
|
||
cm.major_code,
|
||
cm.major_name,
|
||
cm.enrollment_code,
|
||
cm.tuition,
|
||
cm.detail as majorDetail,
|
||
cm.category,
|
||
cm.batch,
|
||
cm.private_student_converted_score as privateStudentScore,
|
||
cm.student_old_converted_score as studentScore,
|
||
cm.student_converted_score,
|
||
cm.enroll_probability,
|
||
cm.rules_enroll_probability_sx,
|
||
cm.rules_enroll_probability,
|
||
cm.probability_operator,
|
||
cm.major_type,
|
||
cm.major_type_child,
|
||
cm.plan_num,
|
||
cm.main_subjects,
|
||
cm.limitation,
|
||
cm.other_score_limitation,
|
||
s.province as province,
|
||
s.school_nature as schoolNature,
|
||
s.institution_type as institutionType
|
||
FROM %s cm
|
||
LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code
|
||
LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id
|
||
LEFT JOIN yx_school s ON s.id = sc.school_id
|
||
WHERE 1=1 AND cm.state > 0
|
||
`, tableName)
|
||
|
||
params := []interface{}{}
|
||
|
||
// 注意:移除了 params = append(params, query.UserScoreVO.CalculationTableName) 因为表名已经通过 Sprintf 插入
|
||
|
||
if query.UserScoreVO.ID != "" {
|
||
countSQL += " AND cm.score_id = ?"
|
||
sql += " AND cm.score_id = ?"
|
||
params = append(params, query.UserScoreVO.ID)
|
||
}
|
||
|
||
if query.MajorType != "" {
|
||
countSQL += " AND cm.major_type = ?"
|
||
sql += " AND cm.major_type = ?"
|
||
params = append(params, query.MajorType)
|
||
}
|
||
if query.Category != "" {
|
||
countSQL += " AND cm.category = ?"
|
||
sql += " AND cm.category = ?"
|
||
params = append(params, query.Category)
|
||
}
|
||
if len(query.MajorTypeChildren) > 0 {
|
||
placeholders := strings.Repeat("?,", len(query.MajorTypeChildren)-1) + "?"
|
||
countSQL += " AND cm.major_type_child IN (" + placeholders + ")"
|
||
sql += " AND cm.major_type_child IN (" + placeholders + ")"
|
||
for _, v := range query.MajorTypeChildren {
|
||
params = append(params, v)
|
||
}
|
||
}
|
||
|
||
if query.MainSubjects != "" {
|
||
countSQL += " AND cm.main_subjects = ?"
|
||
sql += " AND cm.main_subjects = ?"
|
||
params = append(params, query.MainSubjects)
|
||
}
|
||
|
||
// 录取概率
|
||
switch query.Probability {
|
||
case "难录取":
|
||
countSQL += " AND cm.enroll_probability < 60"
|
||
sql += " AND cm.enroll_probability < 60"
|
||
case "可冲击":
|
||
countSQL += " AND (cm.enroll_probability >= 60 and cm.enroll_probability < 73)"
|
||
sql += " AND (cm.enroll_probability >= 60 and cm.enroll_probability < 73)"
|
||
case "较稳妥":
|
||
countSQL += " AND (cm.enroll_probability >= 73 and cm.enroll_probability < 93)"
|
||
sql += " AND (cm.enroll_probability >= 73 and cm.enroll_probability < 93)"
|
||
case "可保底":
|
||
countSQL += " AND (cm.enroll_probability >= 93)"
|
||
sql += " AND (cm.enroll_probability >= 93)"
|
||
}
|
||
|
||
// 移除了无效的 strings.Replace
|
||
|
||
var wg sync.WaitGroup
|
||
var countErr, queryErr error
|
||
|
||
wg.Add(2)
|
||
// 协程1:COUNT 查询
|
||
go func() {
|
||
defer wg.Done()
|
||
countErr = config.DB.Raw(countSQL, params...).Count(&total).Error
|
||
}()
|
||
|
||
// 协程2:主查询
|
||
go func() {
|
||
defer wg.Done()
|
||
sql += fmt.Sprintf(" LIMIT %d OFFSET %d", query.Size, (query.Page-1)*query.Size)
|
||
queryErr = config.DB.Raw(sql, params...).Scan(&items).Error
|
||
}()
|
||
wg.Wait()
|
||
if countErr != nil || queryErr != nil {
|
||
return nil, 0, fmt.Errorf("countErr: %v, queryErr: %v", countErr, queryErr)
|
||
}
|
||
return items, total, queryErr
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) FindByID(id string) (*entity.YxCalculationMajor, error) {
|
||
var item entity.YxCalculationMajor
|
||
err := config.DB.First(&item, "id = ?", id).Error
|
||
return &item, err
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) Create(item *entity.YxCalculationMajor) error {
|
||
return config.DB.Create(item).Error
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) Update(item *entity.YxCalculationMajor) error {
|
||
return config.DB.Save(item).Error
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) UpdateFields(id string, fields map[string]interface{}) error {
|
||
return config.DB.Model(&entity.YxCalculationMajor{}).Where("id = ?", id).Updates(fields).Error
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) Delete(id string) error {
|
||
return config.DB.Delete(&entity.YxCalculationMajor{}, "id = ?", id).Error
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) FindByScoreID(scoreID string) ([]entity.YxCalculationMajor, error) {
|
||
var items []entity.YxCalculationMajor
|
||
err := config.DB.Where("score_id = ?", scoreID).Find(&items).Error
|
||
return items, err
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) BatchCreate(tableName string, items []entity.YxCalculationMajor, batchSize int) error {
|
||
if tableName != "" {
|
||
return config.DB.Table(tableName).CreateInBatches(items, batchSize).Error
|
||
}
|
||
return config.DB.CreateInBatches(items, batchSize).Error
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) BatchUpdate(items []entity.YxCalculationMajor) error {
|
||
return config.DB.Save(items).Error
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) BatchUpsert(items []entity.YxCalculationMajor, updateColumns []string) error {
|
||
return config.DB.Clauses(clause.OnConflict{
|
||
Columns: []clause.Column{{Name: "id"}},
|
||
DoUpdates: clause.AssignmentColumns(updateColumns),
|
||
}).CreateInBatches(items, 100).Error
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) BatchDelete(ids []string) error {
|
||
return config.DB.Delete(&entity.YxCalculationMajor{}, "id IN ?", ids).Error
|
||
}
|
||
|
||
func (m *YxCalculationMajorMapper) DeleteByScoreID(scoreID string) error {
|
||
return config.DB.Delete(&entity.YxCalculationMajor{}, "score_id = ?", scoreID).Error
|
||
}
|