golang-yitisheng-server/server/modules/yx/mapper/yx_calculation_major_mapper.go

565 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package mapper 数据访问层
package mapper
import (
"fmt"
"server/common"
"server/config"
"server/modules/yx/dto"
"server/modules/yx/entity"
"strings"
"sync"
"time"
"gorm.io/gorm/clause"
)
// YxCalculationMajorMapper groups the data-access methods for calculation-major
// rows (entity.YxCalculationMajor) and the per-user calculation tables. It holds
// no state of its own; the methods in this file all go through config.DB.
type YxCalculationMajorMapper struct{}

// NewYxCalculationMajorMapper returns a new, ready-to-use mapper. Since the
// mapper is stateless, any instance is interchangeable with any other.
func NewYxCalculationMajorMapper() *YxCalculationMajorMapper {
	mapper := new(YxCalculationMajorMapper)
	return mapper
}
// QueryCostTime records wall-clock timings of the concurrent queries issued by
// FindRecommendList, which prints them as diagnostics. It can also be reused
// elsewhere.
type QueryCostTime struct {
	CountCost     time.Duration // duration of the total-count query
	ProbCountCost time.Duration // duration of the four-band probability count query
	QueryCost     time.Duration // duration of the main list query
	TotalCost     time.Duration // overall wall time until every query finished
}
// FindAll returns page `page` (1-based) of `size` rows from the model's default
// table, together with the total number of rows in that table.
//
// A failure of the count query is now reported instead of being silently
// dropped (previously it surfaced as total == 0 with a nil error).
func (m *YxCalculationMajorMapper) FindAll(page, size int) ([]entity.YxCalculationMajor, int64, error) {
	var items []entity.YxCalculationMajor
	var total int64
	if err := config.DB.Model(&entity.YxCalculationMajor{}).Count(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("failed to query total count: %w", err)
	}
	err := config.DB.Offset((page - 1) * size).Limit(size).Find(&items).Error
	return items, total, err
}
// FindRecommendList returns three things for the per-user calculation table:
//   - one page of recommended majors;
//   - the number of rows matching the shared filters;
//   - how those rows split across four enroll_probability bands.
//
// Bands by cm.enroll_probability:
//   - 难录取 / hard: < 60
//   - 可冲击 / risky: [60, 73)
//   - 较稳妥 / stable: [73, 93)
//   - 可保底 / safe: >= 93
//
// Inputs read from query:
//   - UserScoreVO.CalculationTableName: the table to read. It is interpolated
//     into the SQL text, so it must be non-empty and pass
//     common.IsValidTableName (the same check the other dynamic-table methods
//     in this file use).
//   - UserScoreVO.ID, SchoolCode, MajorType, Category, MajorTypeChildren, Batch,
//     MainSubjects: optional equality / IN filters, skipped when empty.
//   - Probability: optional band name. It narrows only the returned page.
//   - Page / Size: 1-based page and page size. Page < 1 becomes 1, Size < 1
//     becomes 10, and Size > 100 is capped at 100.
//
// The count, band-count and page queries run concurrently and their timings
// are printed to stdout. On success the returned slice is never nil.
//
// NOTE(review): total ignores query.Probability (only the page is filtered by
// it). Confirm that callers paginate a single band using the band counts rather
// than total.
func (m *YxCalculationMajorMapper) FindRecommendList(query dto.SchoolMajorQuery) ([]dto.UserMajorDTO, int64, dto.ProbabilityCountDTO, error) {
	var items []dto.UserMajorDTO
	var total int64
	var probCount dto.ProbabilityCountDTO // results of the four band counts

	// 1. The table name cannot be bound as a placeholder, so validate it before
	// it is spliced into any SQL text.
	tableName := query.UserScoreVO.CalculationTableName
	if tableName == "" {
		return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("CalculationTableName is empty")
	}
	if !common.IsValidTableName(tableName) {
		return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("invalid table name: %s, potential SQL injection risk", tableName)
	}

	// 2. Shared filters (everything except the probability band), used by all
	// three queries so their results agree.
	baseSQL, params := buildRecommendBaseWhere(query)

	// 3. Total count over the shared filters.
	countSQL := fmt.Sprintf(`
SELECT COUNT(cm.id) FROM %s cm
%s
`, tableName, baseSQL)

	// 4. All four band counts in one pass with CASE WHEN. COALESCE turns the NULL
	// that SUM returns over zero matching rows into 0.
	probCountSQL := fmt.Sprintf(`
SELECT
COALESCE(SUM(CASE WHEN cm.enroll_probability < 60 THEN 1 ELSE 0 END), 0) AS hard,
COALESCE(SUM(CASE WHEN cm.enroll_probability >= 60 AND cm.enroll_probability < 73 THEN 1 ELSE 0 END), 0) AS risky,
COALESCE(SUM(CASE WHEN cm.enroll_probability >= 73 AND cm.enroll_probability < 93 THEN 1 ELSE 0 END), 0) AS stable,
COALESCE(SUM(CASE WHEN cm.enroll_probability >= 93 THEN 1 ELSE 0 END), 0) AS safe
FROM %s cm
%s
`, tableName, baseSQL)

	// 5. Page query, with school info joined in.
	mainSQL := fmt.Sprintf(`
SELECT
cm.id,
s.school_name,
s.school_icon,
cm.state,
cm.school_code,
cm.major_code,
cm.major_name,
cm.enrollment_code,
cm.tuition,
cm.detail as majorDetail,
cm.category,
cm.batch,
cm.private_student_converted_score as privateStudentScore,
cm.student_old_converted_score as studentScore,
cm.student_converted_score,
cm.enroll_probability,
cm.rules_enroll_probability_sx,
cm.rules_enroll_probability,
cm.probability_operator,
cm.major_type,
cm.major_type_child,
cm.plan_num,
cm.main_subjects,
cm.limitation,
cm.other_score_limitation,
s.province,
s.school_nature,
s.institution_type
FROM %s cm
LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code
LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id
LEFT JOIN yx_school s ON s.id = sc.school_id
%s
`, tableName, baseSQL)
	// The probability band narrows only the page query.
	switch query.Probability {
	case "难录取":
		mainSQL += " AND cm.enroll_probability < 60"
	case "可冲击":
		mainSQL += " AND (cm.enroll_probability >= 60 and cm.enroll_probability < 73)"
	case "较稳妥":
		mainSQL += " AND (cm.enroll_probability >= 73 and cm.enroll_probability < 93)"
	case "可保底":
		mainSQL += " AND (cm.enroll_probability >= 93)"
	}
	mainSQL += " ORDER BY cm.enroll_probability DESC"

	// 6. Normalize paging and append LIMIT/OFFSET before any goroutine starts,
	// so no SQL string is mutated concurrently.
	page := query.Page
	size := query.Size
	if page < 1 {
		page = 1
	}
	if size < 1 {
		size = 10
	}
	if size > 100 {
		size = 100
	}
	offset := (page - 1) * size
	mainSQL += fmt.Sprintf(" LIMIT %d OFFSET %d", size, offset)

	// 7. Run the three queries concurrently. Each goroutine writes only its own
	// result/error variables and its own field of queryCost. Distinct struct
	// fields are distinct memory locations, and wg.Wait() orders every write
	// before the reads below, so no mutex is required. params is only read.
	var wg sync.WaitGroup
	var countErr, probCountErr, queryErr error
	var queryCost QueryCostTime
	totalStartTime := time.Now()
	wg.Add(3)
	go func() {
		defer wg.Done()
		start := time.Now()
		countErr = config.DB.Raw(countSQL, params...).Count(&total).Error
		queryCost.CountCost = time.Since(start)
	}()
	go func() {
		defer wg.Done()
		start := time.Now()
		probCountErr = config.DB.Raw(probCountSQL, params...).Scan(&probCount).Error
		queryCost.ProbCountCost = time.Since(start)
	}()
	go func() {
		defer wg.Done()
		start := time.Now()
		queryErr = config.DB.Raw(mainSQL, params...).Scan(&items).Error
		queryCost.QueryCost = time.Since(start)
	}()
	wg.Wait()
	queryCost.TotalCost = time.Since(totalStartTime)

	// Timing diagnostics on stdout. Remove or route to a logger as needed.
	fmt.Printf("各查询耗时统计:\n")
	fmt.Printf(" 总数量查询耗时:%v\n", queryCost.CountCost)
	fmt.Printf(" 概率数量查询耗时:%v\n", queryCost.ProbCountCost)
	fmt.Printf(" 主列表查询耗时:%v\n", queryCost.QueryCost)
	fmt.Printf(" 整体总耗时:%v\n", queryCost.TotalCost)

	// 8. Report the first failure, in a fixed order.
	if countErr != nil {
		return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("failed to query total count: %w", countErr)
	}
	if probCountErr != nil {
		return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("failed to query probability count: %w", probCountErr)
	}
	if queryErr != nil {
		return nil, 0, dto.ProbabilityCountDTO{}, fmt.Errorf("failed to query recommend major list: %w", queryErr)
	}
	if items == nil {
		items = []dto.UserMajorDTO{}
	}
	return items, total, probCount, nil
}

// buildRecommendBaseWhere builds the WHERE clause shared by the count,
// band-count and page queries of FindRecommendList, plus its positional
// arguments in placeholder order. Empty filters are skipped. The
// query.Probability band is deliberately not included here.
func buildRecommendBaseWhere(query dto.SchoolMajorQuery) (string, []interface{}) {
	where := " WHERE 1=1 AND cm.state > 0 "
	params := []interface{}{}
	if query.UserScoreVO.ID != "" {
		where += " AND cm.score_id = ?"
		params = append(params, query.UserScoreVO.ID)
	}
	if query.SchoolCode != "" {
		where += " AND cm.school_code = ?"
		params = append(params, query.SchoolCode)
	}
	if query.MajorType != "" {
		where += " AND cm.major_type = ?"
		params = append(params, query.MajorType)
	}
	if query.Category != "" {
		where += " AND cm.category = ?"
		params = append(params, query.Category)
	}
	if len(query.MajorTypeChildren) > 0 {
		placeholders := strings.Repeat("?,", len(query.MajorTypeChildren)-1) + "?"
		where += " AND cm.major_type_child IN (" + placeholders + ")"
		for _, v := range query.MajorTypeChildren {
			params = append(params, v)
		}
	}
	if query.Batch != "" {
		where += " AND cm.batch = ?"
		params = append(params, query.Batch)
	}
	if query.MainSubjects != "" {
		where += " AND cm.main_subjects = ?"
		params = append(params, query.MainSubjects)
	}
	return where, params
}
// FindRecommendList1 is an older variant of FindRecommendList. It returns one
// page of recommended majors plus the number of rows matching the same filters,
// including the query.Probability band, but no per-band counts. Unlike
// FindRecommendList it does not filter on SchoolCode or Batch, and its count
// query also carries the school LEFT JOINs.
//
// Input requirements:
//   - query.UserScoreVO.CalculationTableName is interpolated into the SQL text,
//     so it must be non-empty and pass common.IsValidTableName.
//   - query.Page < 1 is treated as 1 and query.Size < 1 as 10. Previously these
//     produced a negative OFFSET (SQL error) or LIMIT 0.
//
// Errors from either query are wrapped with %w so callers can inspect them.
func (m *YxCalculationMajorMapper) FindRecommendList1(query dto.SchoolMajorQuery) ([]dto.UserMajorDTO, int64, error) {
	var items []dto.UserMajorDTO
	var total int64

	// The table name cannot be bound as a placeholder; validate it first.
	tableName := query.UserScoreVO.CalculationTableName
	if tableName == "" {
		return nil, 0, fmt.Errorf("CalculationTableName is empty")
	}
	if !common.IsValidTableName(tableName) {
		return nil, 0, fmt.Errorf("invalid table name: %s, potential SQL injection risk", tableName)
	}

	// One WHERE clause (and one params slice) shared verbatim by the count and
	// list queries, so the two can never drift apart.
	where := " WHERE 1=1 AND cm.state > 0"
	params := []interface{}{}
	if query.UserScoreVO.ID != "" {
		where += " AND cm.score_id = ?"
		params = append(params, query.UserScoreVO.ID)
	}
	if query.MajorType != "" {
		where += " AND cm.major_type = ?"
		params = append(params, query.MajorType)
	}
	if query.Category != "" {
		where += " AND cm.category = ?"
		params = append(params, query.Category)
	}
	if len(query.MajorTypeChildren) > 0 {
		placeholders := strings.Repeat("?,", len(query.MajorTypeChildren)-1) + "?"
		where += " AND cm.major_type_child IN (" + placeholders + ")"
		for _, v := range query.MajorTypeChildren {
			params = append(params, v)
		}
	}
	if query.MainSubjects != "" {
		where += " AND cm.main_subjects = ?"
		params = append(params, query.MainSubjects)
	}
	// Enrollment-probability band.
	switch query.Probability {
	case "难录取":
		where += " AND cm.enroll_probability < 60"
	case "可冲击":
		where += " AND (cm.enroll_probability >= 60 and cm.enroll_probability < 73)"
	case "较稳妥":
		where += " AND (cm.enroll_probability >= 73 and cm.enroll_probability < 93)"
	case "可保底":
		where += " AND (cm.enroll_probability >= 93)"
	}

	const joins = `
LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code
LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id
LEFT JOIN yx_school s ON s.id = sc.school_id`

	countSQL := fmt.Sprintf(`
SELECT COUNT(cm.id) FROM %s cm
%s
%s
`, tableName, joins, where)

	listSQL := fmt.Sprintf(`
SELECT
cm.id,
s.school_name,
s.school_icon,
cm.state,
cm.school_code,
cm.major_code,
cm.major_name,
cm.enrollment_code,
cm.tuition,
cm.detail as majorDetail,
cm.category,
cm.batch,
cm.private_student_converted_score as privateStudentScore,
cm.student_old_converted_score as studentScore,
cm.student_converted_score,
cm.enroll_probability,
cm.rules_enroll_probability_sx,
cm.rules_enroll_probability,
cm.probability_operator,
cm.major_type,
cm.major_type_child,
cm.plan_num,
cm.main_subjects,
cm.limitation,
cm.other_score_limitation,
s.province as province,
s.school_nature as schoolNature,
s.institution_type as institutionType
FROM %s cm
%s
%s
`, tableName, joins, where)

	// Normalize paging and finish the list SQL before the goroutines start, so
	// nothing shared is mutated concurrently.
	page := query.Page
	size := query.Size
	if page < 1 {
		page = 1
	}
	if size < 1 {
		size = 10
	}
	listSQL += fmt.Sprintf(" LIMIT %d OFFSET %d", size, (page-1)*size)

	// Run the count and list queries concurrently. Each goroutine writes only
	// its own variables, and wg.Wait() orders those writes before the reads
	// below.
	var wg sync.WaitGroup
	var countErr, queryErr error
	wg.Add(2)
	go func() {
		defer wg.Done()
		countErr = config.DB.Raw(countSQL, params...).Count(&total).Error
	}()
	go func() {
		defer wg.Done()
		queryErr = config.DB.Raw(listSQL, params...).Scan(&items).Error
	}()
	wg.Wait()

	if countErr != nil {
		return nil, 0, fmt.Errorf("failed to query total count: %w", countErr)
	}
	if queryErr != nil {
		return nil, 0, fmt.Errorf("failed to query recommend major list: %w", queryErr)
	}
	return items, total, nil
}
// FindByID loads the row whose primary key equals id from the model's default
// table. A non-nil pointer is always returned; when err is non-nil its contents
// are whatever GORM left in it (normally the zero value).
func (m *YxCalculationMajorMapper) FindByID(id string) (*entity.YxCalculationMajor, error) {
	record := new(entity.YxCalculationMajor)
	if err := config.DB.First(record, "id = ?", id).Error; err != nil {
		return record, err
	}
	return record, nil
}
// Create inserts item into the model's default table.
func (m *YxCalculationMajorMapper) Create(item *entity.YxCalculationMajor) error {
	result := config.DB.Create(item)
	return result.Error
}
// Update persists item with GORM Save, which writes all of its fields.
func (m *YxCalculationMajorMapper) Update(item *entity.YxCalculationMajor) error {
	result := config.DB.Save(item)
	return result.Error
}
// UpdateFields applies the column->value pairs in fields to the row whose id
// equals id, in the model's default table.
func (m *YxCalculationMajorMapper) UpdateFields(id string, fields map[string]interface{}) error {
	scoped := config.DB.Model(&entity.YxCalculationMajor{}).Where("id = ?", id)
	return scoped.Updates(fields).Error
}
// Delete removes the row whose id equals id from the model's default table.
func (m *YxCalculationMajorMapper) Delete(id string) error {
	result := config.DB.Delete(&entity.YxCalculationMajor{}, "id = ?", id)
	return result.Error
}
// FindByScoreID returns every row in the model's default table whose score_id
// equals scoreID. The slice is returned as-is even when err is non-nil.
func (m *YxCalculationMajorMapper) FindByScoreID(scoreID string) ([]entity.YxCalculationMajor, error) {
	var rows []entity.YxCalculationMajor
	if err := config.DB.Where("score_id = ?", scoreID).Find(&rows).Error; err != nil {
		return rows, err
	}
	return rows, nil
}
// FindListByCompositeKeys returns the rows of tableName that belong to scoreId
// and whose (school_code, major_code, enrollment_code) matches any of keys.
//
// Each key must have the form "schoolCode_majorCode_enrollmentCode"; keys that
// do not split into exactly three parts on "_" are skipped. It returns
// (nil, nil) when keys is empty or contains no well-formed key. tableName must
// pass common.IsValidTableName (it is concatenated into the SQL) and scoreId
// must be non-empty.
//
// Fix: separators are now emitted per accepted key rather than per loop index.
// Previously a malformed first key produced "IN (,(?, ?, ?))" and an all-
// malformed list produced "IN ()", both SQL syntax errors.
func (m *YxCalculationMajorMapper) FindListByCompositeKeys(tableName string, keys []string, scoreId string) ([]entity.YxCalculationMajor, error) {
	if len(keys) == 0 {
		return nil, nil
	}
	// The table name is concatenated into the SQL, so guard against injection.
	if !common.IsValidTableName(tableName) {
		return nil, fmt.Errorf("无效的表名: %s", tableName)
	}
	if scoreId == "" {
		return nil, fmt.Errorf("score_id 不能为空")
	}

	// score_id is bound first, followed by three values per accepted key.
	params := make([]interface{}, 0, 1+3*len(keys))
	params = append(params, scoreId)
	tuples := make([]string, 0, len(keys))
	for _, key := range keys {
		parts := strings.Split(key, "_")
		if len(parts) != 3 {
			continue
		}
		tuples = append(tuples, "(?, ?, ?)")
		params = append(params, parts[0], parts[1], parts[2])
	}
	if len(tuples) == 0 {
		// Nothing usable to match; avoid issuing "IN ()".
		return nil, nil
	}

	sqlStr := "SELECT * FROM " + tableName +
		" WHERE score_id = ? AND (school_code, major_code, enrollment_code) IN (" +
		strings.Join(tuples, ",") + ")"

	var items []entity.YxCalculationMajor
	db := config.DB
	if tableName != "" {
		db = db.Table(tableName)
	}
	err := db.Raw(sqlStr, params...).Scan(&items).Error
	return items, err
}
// FindDtoListByCompositeKeys returns the rows of tableName that belong to
// scoreId and match any of the given composite keys. School info is joined in
// and each row is scanned into dto.SchoolMajorDTO.
//
// Each key must look like "schoolCode_majorCode_enrollmentCode"; keys that do
// not split into exactly three parts on "_" are ignored. It returns (nil, nil)
// when keys is empty or none of them is well-formed. tableName must pass
// common.IsValidTableName and scoreId must be non-empty.
func (m *YxCalculationMajorMapper) FindDtoListByCompositeKeys(tableName string, keys []string, scoreId string) ([]dto.SchoolMajorDTO, error) {
	switch {
	case len(keys) == 0:
		return nil, nil
	case !common.IsValidTableName(tableName):
		return nil, fmt.Errorf("无效的表名: %s", tableName)
	case scoreId == "":
		return nil, fmt.Errorf("score_id 不能为空")
	}

	// Gather one "(?, ?, ?)" group per well-formed key; score_id is bound first.
	args := make([]interface{}, 0, 1+3*len(keys))
	args = append(args, scoreId)
	groups := make([]string, 0, len(keys))
	for _, key := range keys {
		if parts := strings.Split(key, "_"); len(parts) == 3 {
			groups = append(groups, "(?, ?, ?)")
			args = append(args, parts[0], parts[1], parts[2])
		}
	}
	if len(groups) == 0 {
		return nil, nil
	}

	// Same projection and joins as the recommend-list queries, restricted to
	// the requested composite keys.
	prefix := fmt.Sprintf(`
SELECT
cm.id,
s.school_name,
s.school_icon,
cm.state,
cm.school_code,
cm.major_code,
cm.major_name,
cm.enrollment_code,
cm.tuition,
cm.detail as majorDetail,
cm.category,
cm.batch,
cm.private_student_converted_score as privateStudentScore,
cm.student_old_converted_score as studentScore,
cm.student_converted_score,
cm.enroll_probability,
cm.rules_enroll_probability_sx,
cm.rules_enroll_probability,
cm.probability_operator,
cm.major_type,
cm.major_type_child,
cm.plan_num,
cm.main_subjects,
cm.limitation,
cm.other_score_limitation,
s.province as province,
s.school_nature as schoolNature,
s.institution_type as institutionType
FROM %s cm
LEFT JOIN yx_school_child sc ON sc.school_code = cm.school_code
LEFT JOIN yx_school_research_teaching srt ON srt.school_id = sc.school_id
LEFT JOIN yx_school s ON s.id = sc.school_id
WHERE cm.score_id = ? AND (cm.school_code, cm.major_code, cm.enrollment_code) IN (
`, tableName)
	statement := prefix + strings.Join(groups, ",") + ")"

	var rows []dto.SchoolMajorDTO
	err := config.DB.Raw(statement, args...).Scan(&rows).Error
	return rows, err
}
// BatchCreate inserts items in chunks of batchSize rows. When tableName is
// non-empty the rows go to that table; otherwise they go to the model's
// default table.
//
// A non-positive batchSize falls back to 100, the chunk size BatchUpsert uses.
// GORM's CreateInBatches advances through the slice by batchSize, so passing 0
// or a negative value is not safe.
func (m *YxCalculationMajorMapper) BatchCreate(tableName string, items []entity.YxCalculationMajor, batchSize int) error {
	if batchSize <= 0 {
		batchSize = 100
	}
	if tableName != "" {
		return config.DB.Table(tableName).CreateInBatches(items, batchSize).Error
	}
	return config.DB.CreateInBatches(items, batchSize).Error
}
// BatchUpdate persists every element of items with GORM Save. For a slice,
// Save issues an INSERT that updates all columns on conflict.
//
// An empty slice is treated as a successful no-op. GORM's Save otherwise
// reports an "empty slice found" error (gorm.ErrEmptySlice) for it, even
// though there is nothing to update.
func (m *YxCalculationMajorMapper) BatchUpdate(items []entity.YxCalculationMajor) error {
	if len(items) == 0 {
		return nil
	}
	return config.DB.Save(items).Error
}
// BatchUpsert inserts items into the model's default table in chunks of 100.
// When a row conflicts on "id", only the columns listed in updateColumns are
// overwritten with the incoming values.
func (m *YxCalculationMajorMapper) BatchUpsert(items []entity.YxCalculationMajor, updateColumns []string) error {
	upsert := clause.OnConflict{
		Columns:   []clause.Column{{Name: "id"}},
		DoUpdates: clause.AssignmentColumns(updateColumns),
	}
	result := config.DB.Clauses(upsert).CreateInBatches(items, 100)
	return result.Error
}
// BatchDelete removes every row whose id appears in ids from the model's
// default table.
func (m *YxCalculationMajorMapper) BatchDelete(ids []string) error {
	result := config.DB.Delete(&entity.YxCalculationMajor{}, "id IN ?", ids)
	return result.Error
}
// DeleteByScoreID removes every row whose score_id equals scoreID from the
// model's default table.
func (m *YxCalculationMajorMapper) DeleteByScoreID(scoreID string) error {
	result := config.DB.Delete(&entity.YxCalculationMajor{}, "score_id = ?", scoreID)
	return result.Error
}
// DeleteByScoreIDFromTable removes every row whose score_id equals scoreID from
// tableName. An empty tableName is a no-op that returns nil.
//
// tableName is used as a table identifier, and GORM's Table() treats names
// containing spaces as raw SQL expressions. It is therefore validated with
// common.IsValidTableName, like the other dynamic-table methods in this file.
func (m *YxCalculationMajorMapper) DeleteByScoreIDFromTable(tableName, scoreID string) error {
	if tableName == "" {
		return nil
	}
	if !common.IsValidTableName(tableName) {
		return fmt.Errorf("无效的表名: %s", tableName)
	}
	return config.DB.Table(tableName).Where("score_id = ?", scoreID).Delete(map[string]interface{}{}).Error
}