原子操作:无锁并发编程的艺术

深入理解 Go 的 sync/atomic 包,掌握原子操作和无锁并发编程技巧

原子操作:无锁并发编程的艺术

在并发编程中,我们经常需要保护共享数据。传统的做法是使用互斥锁(mutex),但锁会带来性能开销和死锁风险。原子操作提供了一种无锁的替代方案,在某些场景下性能更优。

本文将深入探讨 Go 的 sync/atomic 包,学习如何正确使用原子操作。

什么是原子操作?

原子操作是指不可被中断的操作。在执行过程中,不会被其他 goroutine 打断:

// 非原子操作(可能被中断)
counter++  // 实际上是:读取 → 加 1 → 写入

// 原子操作(不可中断)
atomic.AddInt64(&counter, 1)  // 硬件级别的原子指令

为什么需要原子操作?

package main

import (
    "fmt"
    "sync"
)

func main() {
    var counter int64
    var wg sync.WaitGroup
    
    // 启动 1000 个 goroutine
    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            counter++  // ❌ 竞态条件!
        }()
    }
    
    wg.Wait()
    fmt.Println("Counter:", counter)  // 结果不确定,通常小于 1000
}

sync/atomic 基础

基本类型操作

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

func main() {
    var counter int64
    var wg sync.WaitGroup
    
    // 使用原子操作
    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            atomic.AddInt64(&counter, 1)  // ✅ 安全
        }()
    }
    
    wg.Wait()
    fmt.Println("Counter:", counter)  // 1000
}

常用操作

package main

import (
    "fmt"
    "sync/atomic"
)

func main() {
    var value int64
    
    // Store:存储值
    atomic.StoreInt64(&value, 42)
    fmt.Println("Stored:", value)
    
    // Load:读取值
    loaded := atomic.LoadInt64(&value)
    fmt.Println("Loaded:", loaded)
    
    // Add:增加值
    atomic.AddInt64(&value, 10)
    fmt.Println("After Add:", atomic.LoadInt64(&value))
    
    // Swap:交换值,返回旧值
    old := atomic.SwapInt64(&value, 100)
    fmt.Printf("Swapped: old=%d, new=%d\n", old, atomic.LoadInt64(&value))
    
    // CompareAndSwap(CAS):比较并交换
    // 只有当前值等于期望值时才交换
    success := atomic.CompareAndSwapInt64(&value, 100, 200)
    fmt.Printf("CAS: success=%v, value=%d\n", success, atomic.LoadInt64(&value))
    
    // CAS 失败的情况
    success = atomic.CompareAndSwapInt64(&value, 100, 300)
    fmt.Printf("CAS: success=%v, value=%d\n", success, atomic.LoadInt64(&value))
}

Go 1.19 新的 atomic 类型

类型安全的原子操作

package main

import (
    "fmt"
    "sync/atomic"
)

func main() {
    // Go 1.19 引入的新类型
    var (
        i atomic.Int32
        u atomic.Uint64
        b atomic.Bool
    )
    
    // 类型安全的操作
    i.Store(42)
    fmt.Println("Int32:", i.Load())
    
    u.Add(100)
    fmt.Println("Uint64:", u.Load())
    
    b.Store(true)
    fmt.Println("Bool:", b.Load())
    
    // CompareAndSwap
    if i.CompareAndSwap(42, 100) {
        fmt.Println("Swapped to:", i.Load())
    }
    
    // Swap
    old := i.Swap(200)
    fmt.Printf("Swapped: old=%d, new=%d\n", old, i.Load())
}

atomic.Pointer

package main

import (
    "fmt"
    "sync/atomic"
)

type Config struct {
    Timeout  int
    MaxRetry int
}

func main() {
    // 原子指针
    var configPtr atomic.Pointer[Config]
    
    // 存储配置
    config1 := &Config{Timeout: 30, MaxRetry: 3}
    configPtr.Store(config1)
    
    // 读取配置
    current := configPtr.Load()
    fmt.Printf("Config: %+v\n", current)
    
    // 原子更新配置
    config2 := &Config{Timeout: 60, MaxRetry: 5}
    old := configPtr.Swap(config2)
    
    fmt.Printf("Old config: %+v\n", old)
    fmt.Printf("New config: %+v\n", configPtr.Load())
    
    // CompareAndSwap
    config3 := &Config{Timeout: 90, MaxRetry: 10}
    if configPtr.CompareAndSwap(config2, config3) {
        fmt.Println("Config updated successfully")
    }
}

实战:无锁数据结构

1. 原子计数器

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

// Counter 线程安全的计数器
type Counter struct {
    value atomic.Int64
}

func (c *Counter) Increment() {
    c.value.Add(1)
}

func (c *Counter) Decrement() {
    c.value.Add(-1)
}

func (c *Counter) Value() int64 {
    return c.value.Load()
}

func (c *Counter) Reset() {
    c.value.Store(0)
}

func main() {
    counter := &Counter{}
    var wg sync.WaitGroup
    
    // 并发增加
    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            counter.Increment()
        }()
    }
    
    wg.Wait()
    fmt.Println("Final count:", counter.Value())  // 1000
    
    // 并发增减
    for i := 0; i < 500; i++ {
        wg.Add(2)
        go func() {
            defer wg.Done()
            counter.Increment()
        }()
        go func() {
            defer wg.Done()
            counter.Decrement()
        }()
    }
    
    wg.Wait()
    fmt.Println("After increment/decrement:", counter.Value())  // 1000
}

2. 无锁队列

package main

import (
    "fmt"
    "sync/atomic"
    "unsafe"
)

// Node 队列节点
type Node struct {
    value interface{}
    next  unsafe.Pointer
}

// LockFreeQueue 无锁队列
type LockFreeQueue struct {
    head unsafe.Pointer
    tail unsafe.Pointer
}

func NewLockFreeQueue() *LockFreeQueue {
    dummy := &Node{}
    return &LockFreeQueue{
        head: unsafe.Pointer(dummy),
        tail: unsafe.Pointer(dummy),
    }
}

// Enqueue 入队
func (q *LockFreeQueue) Enqueue(value interface{}) {
    node := &Node{value: value}
    
    for {
        tail := (*Node)(atomic.LoadPointer(&q.tail))
        next := atomic.LoadPointer(&tail.next)
        
        // 检查 tail 是否仍然是最后一个
        if tail == (*Node)(atomic.LoadPointer(&q.tail)) {
            if next == nil {
                // 尝试将新节点链接到 tail
                if atomic.CompareAndSwapPointer(&tail.next, nil, unsafe.Pointer(node)) {
                    // 成功,尝试更新 tail
                    atomic.CompareAndSwapPointer(&q.tail, unsafe.Pointer(tail), unsafe.Pointer(node))
                    return
                }
            } else {
                // tail 不是最后一个,更新 tail
                atomic.CompareAndSwapPointer(&q.tail, unsafe.Pointer(tail), next)
            }
        }
    }
}

// Dequeue 出队
func (q *LockFreeQueue) Dequeue() (interface{}, bool) {
    for {
        head := (*Node)(atomic.LoadPointer(&q.head))
        tail := (*Node)(atomic.LoadPointer(&q.tail))
        next := (*Node)(atomic.LoadPointer(&head.next))
        
        // 检查一致性
        if head == (*Node)(atomic.LoadPointer(&q.head)) {
            if head == tail {
                if next == nil {
                    return nil, false  // 队列为空
                }
                // tail 落后了,更新它
                atomic.CompareAndSwapPointer(&q.tail, unsafe.Pointer(tail), unsafe.Pointer(next))
            } else {
                // 读取值,然后尝试更新 head
                value := next.value
                if atomic.CompareAndSwapPointer(&q.head, unsafe.Pointer(head), unsafe.Pointer(next)) {
                    return value, true
                }
            }
        }
    }
}

func main() {
    q := NewLockFreeQueue()
    
    // 入队
    q.Enqueue("first")
    q.Enqueue("second")
    q.Enqueue("third")
    
    // 出队
    for {
        value, ok := q.Dequeue()
        if !ok {
            break
        }
        fmt.Println("Dequeued:", value)
    }
}

3. 无锁栈

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
    "unsafe"
)

type StackNode struct {
    value interface{}
    next  unsafe.Pointer
}

type LockFreeStack struct {
    top unsafe.Pointer
}

func NewLockFreeStack() *LockFreeStack {
    return &LockFreeStack{}
}

// Push 压栈
func (s *LockFreeStack) Push(value interface{}) {
    node := &StackNode{value: value}
    
    for {
        top := atomic.LoadPointer(&s.top)
        node.next = top
        
        if atomic.CompareAndSwapPointer(&s.top, top, unsafe.Pointer(node)) {
            return
        }
    }
}

// Pop 弹栈
func (s *LockFreeStack) Pop() (interface{}, bool) {
    for {
        top := atomic.LoadPointer(&s.top)
        if top == nil {
            return nil, false
        }
        
        node := (*StackNode)(top)
        next := atomic.LoadPointer(&node.next)
        
        if atomic.CompareAndSwapPointer(&s.top, top, next) {
            return node.value, true
        }
    }
}

func main() {
    stack := NewLockFreeStack()
    var wg sync.WaitGroup
    
    // 并发压栈
    for i := 0; i < 100; i++ {
        wg.Add(1)
        go func(val int) {
            defer wg.Done()
            stack.Push(val)
        }(i)
    }
    
    wg.Wait()
    
    // 并发弹栈
    count := int64(0)
    for i := 0; i < 100; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            if _, ok := stack.Pop(); ok {
                atomic.AddInt64(&count, 1)
            }
        }()
    }
    
    wg.Wait()
    fmt.Println("Popped:", count, "items")  // 100
}

CAS 循环模式

乐观锁实现

package main

import (
    "fmt"
    "sync/atomic"
)

type Account struct {
    balance atomic.Int64
}

func (a *Account) Deposit(amount int64) {
    a.balance.Add(amount)
}

// Withdraw 使用 CAS 循环实现取款
func (a *Account) Withdraw(amount int64) bool {
    for {
        current := a.balance.Load()
        if current < amount {
            return false  // 余额不足
        }
        
        // 尝试原子更新
        if a.balance.CompareAndSwap(current, current-amount) {
            return true
        }
        // CAS 失败,重试
    }
}

func (a *Account) Balance() int64 {
    return a.balance.Load()
}

func main() {
    account := &Account{}
    account.Deposit(1000)
    
    fmt.Println("Initial balance:", account.Balance())
    
    // 尝试取款
    if account.Withdraw(300) {
        fmt.Println("Withdrew 300, new balance:", account.Balance())
    }
    
    if !account.Withdraw(800) {
        fmt.Println("Failed to withdraw 800, balance:", account.Balance())
    }
}

无锁链表

package main

import (
    "fmt"
    "sync/atomic"
    "unsafe"
)

type ListNode struct {
    value int
    next  unsafe.Pointer
}

type LockFreeList struct {
    head unsafe.Pointer
}

func NewLockFreeList() *LockFreeList {
    return &LockFreeList{}
}

// Insert 在头部插入
func (l *LockFreeList) Insert(value int) {
    node := &ListNode{value: value}
    
    for {
        head := atomic.LoadPointer(&l.head)
        node.next = head
        
        if atomic.CompareAndSwapPointer(&l.head, head, unsafe.Pointer(node)) {
            return
        }
    }
}

// Contains 检查是否包含某个值
func (l *LockFreeList) Contains(value int) bool {
    current := (*ListNode)(atomic.LoadPointer(&l.head))
    
    for current != nil {
        if current.value == value {
            return true
        }
        current = (*ListNode)(atomic.LoadPointer(&current.next))
    }
    
    return false
}

// Print 打印链表
func (l *LockFreeList) Print() {
    current := (*ListNode)(atomic.LoadPointer(&l.head))
    
    for current != nil {
        fmt.Printf("%d -> ", current.value)
        current = (*ListNode)(atomic.LoadPointer(&current.next))
    }
    fmt.Println("nil")
}

func main() {
    list := NewLockFreeList()
    
    list.Insert(3)
    list.Insert(2)
    list.Insert(1)
    
    list.Print()  // 1 -> 2 -> 3 -> nil
    
    fmt.Println("Contains 2:", list.Contains(2))  // true
    fmt.Println("Contains 5:", list.Contains(5))  // false
}

性能对比

原子操作 vs 互斥锁

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
    "testing"
)

// 使用互斥锁
type MutexCounter struct {
    mu    sync.Mutex
    value int64
}

func (c *MutexCounter) Increment() {
    c.mu.Lock()
    c.value++
    c.mu.Unlock()
}

func (c *MutexCounter) Value() int64 {
    c.mu.Lock()
    defer c.mu.Unlock()
    return c.value
}

// 使用原子操作
type AtomicCounter struct {
    value atomic.Int64
}

func (c *AtomicCounter) Increment() {
    c.value.Add(1)
}

func (c *AtomicCounter) Value() int64 {
    return c.value.Load()
}

func BenchmarkMutexCounter(b *testing.B) {
    counter := &MutexCounter{}
    b.RunParallel(func(pb *testing.PB) {
        for pb.Next() {
            counter.Increment()
        }
    })
}

func BenchmarkAtomicCounter(b *testing.B) {
    counter := &AtomicCounter{}
    b.RunParallel(func(pb *testing.PB) {
        for pb.Next() {
            counter.Increment()
        }
    })
}

func main() {
    fmt.Println("Run: go test -bench=. -benchmem")
    // 原子操作通常快 5-10 倍
}

注意事项和陷阱

1. ABA 问题

package main

import (
    "fmt"
    "sync/atomic"
)

// ABA 问题:值从 A 变成 B,又变回 A
// CAS 会认为值没有变化

func main() {
    var value atomic.Int64
    value.Store(1)
    
    // Goroutine 1
    go func() {
        current := value.Load()  // 读取 1
        // 被中断...
        
        // Goroutine 2 修改了值
        value.Store(2)
        value.Store(1)  // 又改回 1
        
        // Goroutine 1 继续
        if value.CompareAndSwap(current, 10) {
            fmt.Println("CAS succeeded (but value was changed!)")
        }
    }()
    
    // 解决方案:使用版本号或 atomic.Pointer
}

2. 内存顺序

package main

import (
    "sync/atomic"
)

// Go 的 atomic 操作使用强一致性的内存顺序
// 等价于 C++ 的 memory_order_seq_cst

func main() {
    var a, b atomic.Int64
    
    // Go 保证了操作的顺序性
    a.Store(1)
    b.Store(2)
    
    // 其他 goroutine 看到的顺序要么是 a=1,b=0
    // 要么是 a=1,b=2,不会出现 a=0,b=2
}

3. 不要滥用原子操作

// ❌ 不好:复杂逻辑使用原子操作
type User struct {
    name atomic.Value
    age  atomic.Int64
}

func (u *User) Update(name string, age int64) {
    // 无法保证 name 和 age 的原子性更新
    u.name.Store(name)
    u.age.Store(age)
}

// ✅ 好:复杂逻辑使用互斥锁
type User struct {
    mu   sync.RWMutex
    name string
    age  int64
}

func (u *User) Update(name string, age int64) {
    u.mu.Lock()
    u.name = name
    u.age = age
    u.mu.Unlock()
}

最佳实践

1. 简单计数用 atomic

type Metrics struct {
    requests atomic.Uint64
    errors   atomic.Uint64
}

func (m *Metrics) RecordRequest() {
    m.requests.Add(1)
}

func (m *Metrics) RecordError() {
    m.errors.Add(1)
}

2. 配置更新用 atomic.Pointer

type Config atomic.Pointer[AppConfig]

func (c *Config) Update(newConfig *AppConfig) {
    c.Store(newConfig)
}

func (c *Config) Get() *AppConfig {
    return c.Load()
}

3. 状态标志用 atomic.Bool

type Service struct {
    running atomic.Bool
}

func (s *Service) Start() {
    s.running.Store(true)
}

func (s *Service) Stop() {
    s.running.Store(false)
}

func (s *Service) IsRunning() bool {
    return s.running.Load()
}

总结

原子操作是无锁并发编程的基础:

优势:

  • 比互斥锁更快(无上下文切换)
  • 不会死锁
  • 适合简单的并发控制

适用场景:

  • 计数器
  • 标志位
  • 单个值的原子更新
  • 无锁数据结构

不适用场景:

  • 需要同时更新多个值
  • 复杂的业务逻辑
  • 需要长时间持有"锁"

最佳实践:

  1. 使用 Go 1.19+ 的新 atomic 类型
  2. 理解 CAS 循环模式
  3. 注意 ABA 问题
  4. 不要过度使用

记住:原子操作是工具,不是银弹。在合适的场景使用合适的工具,才能写出既高效又正确的并发程序。

继续阅读

探索更多技术文章

浏览归档,发现更多关于系统设计、工具链和工程实践的内容。

全部文章 返回首页