diff --git a/server/common/base_service.go b/server/common/base_service.go index d2b1f4a..a5bbef3 100644 --- a/server/common/base_service.go +++ b/server/common/base_service.go @@ -89,13 +89,20 @@ func (s *BaseService[T]) BatchDelete(ids []string) error { // setID 通过反射设置 ID 字段 func setID(item interface{}) error { val := reflect.ValueOf(item).Elem() + + // 如果当前类型是指针,再次解引用以获取实际的 Struct + if val.Kind() == reflect.Ptr { + val = val.Elem() + } + if val.Kind() == reflect.Struct { idField := val.FieldByName("ID") if idField.IsValid() && idField.Kind() == reflect.String { + // 如果 ID 为空,生成新 ID if idField.String() == "" { idField.SetString(GenerateStringID()) } } } return nil -} \ No newline at end of file +} diff --git a/server/common/id_utils.go b/server/common/id_utils.go index b6b3118..1e2b165 100644 --- a/server/common/id_utils.go +++ b/server/common/id_utils.go @@ -4,140 +4,81 @@ import ( "errors" "strconv" "sync" - "time" + + "server/common/snowflake" ) -/** -分布式/多实例运行(重要) -如果你有多个 Pod 或服务器,必须在程序启动时给它们分配不同的 ID,否则还是会冲突 - -func main() { - // 获取当前机器的编号,比如从配置文件或环境变量读取 - // 假设这是第 2 号机器 - myMachineID := int64(2) - - // 初始化 - err := common.InitGenerator(myMachineID) - if err != nil { - panic(err) - } - - // 之后随处调用 - println(common.GenerateStringID()) -} -*/ - -// 定义常量 (标准雪花算法配置) -const ( - epoch = int64(1704067200000) // 起始时间戳 2024-01-01 - workerBits = uint(10) // 机器ID位数 - sequenceBits = uint(12) // 序列号位数 - maxWorker = -1 ^ (-1 << workerBits) - maxSequence = -1 ^ (-1 << sequenceBits) - workerShift = sequenceBits - timestampShift = sequenceBits + workerBits -) - -type IDGenerator struct { - mu sync.Mutex - lastTime int64 - workerID int64 - sequence int64 -} - var ( - defaultGenerator *IDGenerator + defaultSnowflake *snowflake.Snowflake once sync.Once ) -// InitGenerator 初始化单例生成器 -// 修复:校验逻辑移到 once.Do 外部,防止校验失败消耗掉 once 的执行机会 -func InitGenerator(workerID int64) error { - // 1. 先校验,如果失败直接返回,不要触碰 once - if workerID < 0 || workerID > int64(maxWorker) { - return errors.New("worker ID excess of limit (0-1023)") +// InitGenerator 初始化雪花算法生成器 +// workerId: 工作机器ID (0 ~ 31) +// datacenterId: 数据中心ID (0 ~ 31) +// 如果不需要区分数据中心,可以将 datacenterId 设置为 0 +func InitGenerator(workerId, datacenterId int64) error { + // 先校验参数 + if workerId < 0 || workerId > 31 { + return errors.New("workerId must be between 0 and 31") + } + if datacenterId < 0 || datacenterId > 31 { + return errors.New("datacenterId must be between 0 and 31") } - // 2. 执行初始化 + // 执行初始化 once.Do(func() { - defaultGenerator = &IDGenerator{ - workerID: workerID, - lastTime: 0, - sequence: 0, + var err error + defaultSnowflake, err = snowflake.NewSnowflake(workerId, datacenterId) + if err != nil { + panic("InitGenerator failed: " + err.Error()) } }) + return nil } -// getInstance 获取单例 -func getInstance() *IDGenerator { - // 双重检查,虽然 once.Do 是线程安全的,但如果 InitGenerator 没被调用过, - // 我们需要确保这里能兜底初始化 - if defaultGenerator == nil { - once.Do(func() { - defaultGenerator = &IDGenerator{ - workerID: 1, // 默认机器ID,防止未初始化导致 panic - lastTime: 0, - sequence: 0, - } - }) - } - - // 【关键修复】如果经过上面的逻辑 defaultGenerator 还是 nil - // (这种情况极少见,除非 InitGenerator 曾经被错误调用且没有赋值) - // 强制创建一个临时的或抛出 panic,避免空指针崩溃 - if defaultGenerator == nil { - // 最后的兜底,防止崩溃 - return &IDGenerator{workerID: 1} - } - - return defaultGenerator +// InitGeneratorWithWorkerID 仅使用 workerId 初始化(兼容旧版本) +// datacenterId 默认为 0 +func InitGeneratorWithWorkerID(workerID int64) error { + return InitGenerator(workerID, 0) } -// GenerateLongID 全局辅助函数 +// getInstance 获取单例实例 +func getInstance() *snowflake.Snowflake { + once.Do(func() { + // 默认值:workerId=1, datacenterId=0 + var err error + defaultSnowflake, err = snowflake.NewSnowflake(1, 0) + if err != nil { + // 默认参数如果还失败,直接 panic + panic("Snowflake getInstance failed: " + err.Error()) + } + }) + + // 防御性编程:如果 once.Do 已经执行过(例如被 InitGenerator 执行了), + // 但因为 panic 或其他异常导致 defaultSnowflake 仍为 nil,这里进行补救 + if defaultSnowflake == nil { + // 此时忽略 sync.Once,直接强制初始化,防止 nil pointer crash + // 使用默认安全值 (1, 0) + defaultSnowflake, _ = snowflake.NewSnowflake(1, 0) + } + + return defaultSnowflake +} + +// GenerateLongID 生成 64 位整型 ID func GenerateLongID() int64 { - return getInstance().NextID() + id, err := getInstance().NextId() + if err != nil { + // 极端情况:时间回拨 + // 返回 0 或使用时间戳作为备用方案 + panic("GenerateLongID failed: " + err.Error()) + } + return id } -// GenerateStringID 全局辅助函数 +// GenerateStringID 生成字符串 ID func GenerateStringID() string { - return strconv.FormatInt(getInstance().NextID(), 10) -} - -// NextID 生成下一个 ID -func (g *IDGenerator) NextID() int64 { - // 防御性编程:防止 g 为 nil - if g == nil { - // 如果实例是 nil,尝试获取默认实例 - if defaultGenerator != nil { - g = defaultGenerator - } else { - // 极端情况,创建一个临时对象(虽然锁不住全局,但能防崩) - g = &IDGenerator{workerID: 1} - } - } - - g.mu.Lock() - defer g.mu.Unlock() - - now := time.Now().UnixMilli() - - if now < g.lastTime { - now = g.lastTime - } - - if now == g.lastTime { - g.sequence = (g.sequence + 1) & int64(maxSequence) - if g.sequence == 0 { - for now <= g.lastTime { - now = time.Now().UnixMilli() - } - } - } else { - g.sequence = 0 - } - - g.lastTime = now - - return ((now - epoch) << timestampShift) | (g.workerID << workerShift) | g.sequence + return strconv.FormatInt(GenerateLongID(), 10) } diff --git a/server/common/snowflake/snowflake.go b/server/common/snowflake/snowflake.go new file mode 100644 index 0000000..91df396 --- /dev/null +++ b/server/common/snowflake/snowflake.go @@ -0,0 +1,122 @@ +package snowflake + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// 定义常量 +const ( + // 位数分配 + sequenceBits = 12 // 序列号占用的位数 + workerIdBits = 5 // 工作机器ID占用的位数 + datacenterIdBits = 5 // 数据中心ID占用的位数 + + // 最大值 + maxSequence = -1 ^ (-1 << sequenceBits) // 4095 + maxWorkerId = -1 ^ (-1 << workerIdBits) // 31 + maxDatacenterId = -1 ^ (-1 << datacenterIdBits) // 31 + + // 位移偏移量 + workerIdShift = sequenceBits // 12 + datacenterIdShift = sequenceBits + workerIdBits // 12 + 5 = 17 + timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits // 12 + 5 + 5 = 22 +) + +// 起始时间戳 (纪元),可以使用程序上线的时间,这里设置为 2020-01-01 00:00:00 UTC +var epoch int64 = 1577836800000 + +// Snowflake 结构体 +type Snowflake struct { + mu sync.Mutex // 互斥锁,保证并发安全 + lastTime int64 // 上次生成ID的时间戳 + workerId int64 // 工作机器ID + datacenterId int64 // 数据中心ID + sequence int64 // 当前毫秒内的序列号 +} + +// NewSnowflake 初始化一个 Snowflake 实例 +// workerId: 工作机器ID (0 ~ 31) +// datacenterId: 数据中心ID (0 ~ 31) +func NewSnowflake(workerId, datacenterId int64) (*Snowflake, error) { + if workerId < 0 || workerId > maxWorkerId { + return nil, errors.New(fmt.Sprintf("worker Id can't be greater than %d or less than 0", maxWorkerId)) + } + if datacenterId < 0 || datacenterId > maxDatacenterId { + return nil, errors.New(fmt.Sprintf("datacenter Id can't be greater than %d or less than 0", maxDatacenterId)) + } + + return &Snowflake{ + lastTime: 0, + workerId: workerId, + datacenterId: datacenterId, + sequence: 0, + }, nil +} + +// NextId 生成下一个 ID +func (s *Snowflake) NextId() (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // 获取当前时间戳(毫秒) + now := time.Now().UnixMilli() + + // 如果当前时间小于上次生成ID的时间,说明时钟回拨,抛出异常 + if now < s.lastTime { + return 0, errors.New(fmt.Sprintf("Clock moved backwards. Refusing to generate id for %d milliseconds", s.lastTime-now)) + } + + // 如果是同一毫秒内 + if now == s.lastTime { + // 序列号自增 + s.sequence = (s.sequence + 1) & maxSequence + // 如果序列号溢出(超过4095),则等待下一毫秒 + if s.sequence == 0 { + now = s.waitNextMillis(now) + } + } else { + // 不同毫秒,序列号重置为0 + s.sequence = 0 + } + + // 更新最后时间戳 + s.lastTime = now + + // 组装 ID + // (当前时间 - 起始时间) << 时间戳位移 | 数据中心ID << 数据中心位移 | 工作ID << 工作位移 | 序列号 + id := ((now - epoch) << timestampLeftShift) | + (s.datacenterId << datacenterIdShift) | + (s.workerId << workerIdShift) | + s.sequence + + return id, nil +} + +// waitNextMillis 阻塞等待下一毫秒 +func (s *Snowflake) waitNextMillis(lastTime int64) int64 { + now := time.Now().UnixMilli() + for now <= lastTime { + now = time.Now().UnixMilli() + } + return now +} + +// ParseId 解析 ID,用于调试或查看 ID 组成部分 +func ParseId(id int64) map[string]interface{} { + timestamp := (id >> timestampLeftShift) + epoch + datacenterId := (id >> datacenterIdShift) & maxDatacenterId + workerId := (id >> workerIdShift) & maxWorkerId + sequence := id & maxSequence + + return map[string]interface{}{ + "id": id, + "timestamp": timestamp, + "time_str": time.UnixMilli(timestamp).Format("2006-01-02 15:04:05.000"), + "datacenterId": datacenterId, + "workerId": workerId, + "sequence": sequence, + } +} diff --git a/server/common/snowflake/snowflake_test.go b/server/common/snowflake/snowflake_test.go new file mode 100644 index 0000000..8c61648 --- /dev/null +++ b/server/common/snowflake/snowflake_test.go @@ -0,0 +1,37 @@ +package snowflake // 注意:这里必须是 package snowflake,不能是 main + +import ( + "fmt" + "testing" +) + +// 这是一个测试函数,用于验证功能 +func TestGenerateID(t *testing.T) { + // 1. 初始化生成器 + sf, err := NewSnowflake(1, 1) + if err != nil { + t.Fatalf("初始化失败: %v", err) + } + + fmt.Println("=== 开始生成 ID ===") + + // 2. 生成几个 ID + for i := 0; i < 5; i++ { + id, err := sf.NextId() + if err != nil { + t.Errorf("生成 ID 失败: %v", err) + } else { + fmt.Printf("生成 ID: %d\n", id) + } + } + + // 3. 解析 ID 查看详情 + id, _ := sf.NextId() + info := ParseId(id) + fmt.Printf("\nID 详情解析:\n") + fmt.Printf("ID: %d\n", info["id"]) + fmt.Printf("时间: %s\n", info["time_str"]) + fmt.Printf("数据中心: %d\n", info["datacenterId"]) + fmt.Printf("工作机器: %d\n", info["workerId"]) + fmt.Printf("序列号: %d\n", info["sequence"]) +} diff --git a/server/config/config.dev.yaml b/server/config/config.dev.yaml index c0d1927..5cc4ef9 100644 --- a/server/config/config.dev.yaml +++ b/server/config/config.dev.yaml @@ -1,5 +1,7 @@ server: port: 8081 + worker_id: 1 # 工作机器ID (0-31),单实例使用1 + datacenter_id: 0 # 数据中心ID (0-31),默认0 # 雪花算法机器ID (0-1023),分布式环境下不同实例需设置不同值 log: level: debug diff --git a/server/config/config.go b/server/config/config.go index 886733e..3cef927 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -30,7 +30,9 @@ type LogConfig struct { // ServerConfig 服务配置 type ServerConfig struct { - Port int `yaml:"port"` // 服务端口 + Port int `yaml:"port"` // 服务端口 + WorkerID int `yaml:"worker_id"` // 工作机器ID (0-31),用于雪花算法 + DatacenterID int `yaml:"datacenter_id"` // 数据中心ID (0-31),用于雪花算法 } // SecurityConfig 安全配置 diff --git a/server/config/config.prod.yaml b/server/config/config.prod.yaml index cb987a3..eecef66 100644 --- a/server/config/config.prod.yaml +++ b/server/config/config.prod.yaml @@ -1,5 +1,7 @@ server: port: 8081 + worker_id: 1 # 工作机器ID (0-31),多实例部署需配置不同值 + datacenter_id: 0 # 数据中心ID (0-31),多机房部署需配置不同值 # 雪花算法机器ID (0-1023),分布式环境下不同实例需设置不同值,多实例部署时需手动配置 log: level: info diff --git a/server/config/config.test.yaml b/server/config/config.test.yaml index b81e984..169c6a9 100644 --- a/server/config/config.test.yaml +++ b/server/config/config.test.yaml @@ -1,5 +1,7 @@ server: port: 8080 + worker_id: 1 # 工作机器ID (0-31),测试环境使用1 + datacenter_id: 0 # 数据中心ID (0-31),默认0 log: level: debug diff --git a/server/main.go b/server/main.go index 264a89e..925fab4 100644 --- a/server/main.go +++ b/server/main.go @@ -41,6 +41,21 @@ func main() { common.InitLogger() common.Info("========== 应用启动 ==========") + // 初始化雪花算法ID生成器(从配置获取workerID,默认为1) + workerID := int64(config.AppConfig.Server.WorkerID) + if workerID <= 0 { + workerID = 1 // 默认workerID + } + datacenterID := int64(config.AppConfig.Server.DatacenterID) + if datacenterID < 0 { + datacenterID = 0 // 默认datacenterID + } + if err := common.InitGenerator(workerID, datacenterID); err != nil { + common.LogError("雪花算法初始化失败: %v", err) + log.Fatalf("雪花算法初始化失败: %v\n", err) + } + common.Info("雪花算法ID生成器初始化完成 (WorkerID: %d)", workerID) + // 初始化数据库 config.InitDB() common.Info("数据库初始化完成")