Synchronization issue in a Golang program

yv5phkfx · asked 2023-04-18 · Go

I am trying to create a program that functions as a proxy server and can switch dynamically to a new endpoint. The problem I am facing is that after switchOverToNewEndpoint() is called, some proxy objects are still connected to the original endpoint 8.8.8.8, even though those connections should have been closed.

package main

import (
    "net"
    "sync"
    "sync/atomic"
    "time"
)

type Proxy struct {
    ID       int32
    From, To *net.TCPConn
}

var switchOver int32 = 0

func SetSwitchOver() {
    atomic.StoreInt32((*int32)(&switchOver), 1)
}

func SwitchOverEnabled() bool {
    return atomic.LoadInt32((*int32)(&switchOver)) == 1
}

var proxies map[int32]*Proxy = make(map[int32]*Proxy, 0)
var proxySeq int32 = 0
var mu sync.RWMutex

func addProxy(from *net.TCPConn) {
    mu.Lock()
    proxySeq += 1
    proxy := &Proxy{ID: proxySeq, From: from}
    proxies[proxySeq] = proxy
    mu.Unlock()

    var toAddr string
    if SwitchOverEnabled() {
        toAddr = "1.1.1.1"
    } else {
        toAddr = "8.8.8.8"
    }
    tcpAddr, _ := net.ResolveTCPAddr("tcp4", toAddr)
    toConn, err := net.DialTCP("tcp", nil, tcpAddr)
    if err != nil {
        panic(err)
    }
    proxy.To = toConn
}

func switchOverToNewEndpoint() {
    mu.RLock()
    closedProxies := proxies
    mu.RUnlock()

    SetSwitchOver()
    for _, proxy := range closedProxies {
        proxy.From.Close()
        proxy.To.Close()
        mu.Lock()
        delete(proxies, proxy.ID)
        mu.Unlock()
    }
}

func main() {
    tcpAddr, _ := net.ResolveTCPAddr("tcp4", "0.0.0.0:5432")
    ln, _ := net.ListenTCP("tcp", tcpAddr)
    go func() {
        time.Sleep(time.Second * 30)
        switchOverToNewEndpoint()
    }()
    for {
        clientConn, err := ln.AcceptTCP()
        if err != nil {
            panic(err)
        }
        go addProxy(clientConn)
    }
}

After thinking about it for a while, my guess is that the problem lies in

mu.RLock()
closedProxies := proxies
mu.RUnlock()

But I am not sure whether this is the root cause, and whether it can be fixed by replacing it with the following:

closedProxies := make([]*Proxy, 0)
mu.RLock()
for _, proxy := range proxies {
    closedProxies = append(closedProxies, proxy)
}
mu.RUnlock()

Since this case is hard to reproduce, could someone with the relevant expertise offer an idea or a hint? Any comments are welcome. Thanks in advance.

bihw5rsg1#

The problem

This change is necessary. In the original implementation, closedProxies holds the same map: assigning a map in Go only copies a reference to the same underlying data. See this demo:

package main

import "fmt"

func main() {
    proxies := make(map[int]int, 0)
    for i := 0; i < 10; i++ {
        proxies[i] = i
    }

    closeProxies := proxies

    proxies[10] = 10
    proxies[11] = 11

    for k := range closeProxies {
        delete(proxies, k)
    }

    fmt.Printf("items left: %d\n", len(proxies))
    // Output:
    //   items left: 0
}

But that is not the root cause. A new proxy could be added after closedProxies is copied and before SetSwitchOver is called. In that case, the new proxy connects to the old address yet is not in closedProxies. I think that is the root cause.
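
To make that window concrete, below is a small self-contained sketch of the interleaving. It is not the original program: channels force the unlucky ordering deterministically, and plain strings stand in for real TCP connections, but the snapshot, the flag check, and the shared map mirror the code above.

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

// Shared state mirroring the original program; strings stand in for connections.
var (
    switchOver int32
    mu         sync.RWMutex
    proxies    = map[int]string{} // proxy ID -> address it dialed ("" = not dialed yet)
)

func main() {
    snapshotTaken := make(chan struct{})
    proxyDialed := make(chan struct{})

    // Plays the role of addProxy. It runs entirely inside the window between
    // the snapshot and SetSwitchOver.
    go func() {
        <-snapshotTaken

        mu.Lock()
        proxies[1] = "" // registered, address not chosen yet
        mu.Unlock()

        addr := "8.8.8.8" // old endpoint
        if atomic.LoadInt32(&switchOver) == 1 {
            addr = "1.1.1.1"
        }
        mu.Lock()
        proxies[1] = addr
        mu.Unlock()

        close(proxyDialed)
    }()

    // Plays the role of switchOverToNewEndpoint.
    mu.RLock()
    snapshot := make([]int, 0, len(proxies))
    for id := range proxies {
        snapshot = append(snapshot, id)
    }
    mu.RUnlock()
    close(snapshotTaken)

    <-proxyDialed // the unlucky addProxy finishes before the flag is flipped
    atomic.StoreInt32(&switchOver, 1)

    // Only proxies in the snapshot would be closed; proxy 1 is not among them,
    // yet it dialed the old endpoint and therefore leaks.
    fmt.Printf("proxies in snapshot: %d\n", len(snapshot)) // 0
    mu.RLock()
    fmt.Printf("proxy 1 dialed: %s\n", proxies[1]) // 8.8.8.8
    mu.RUnlock()
}
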
There is another problem: a new proxy is added to proxies before its To field is set. The program may try to close such a proxy before To is set, which causes a panic.

A reliable design

The idea is to put all the endpoints into a slice and let each endpoint manage its own list of proxies. Then we only need to track the index of the current endpoint; when we want to switch to another endpoint, we just change the index and tell the outdated endpoint to clear its proxies. The only remaining complication is making sure the outdated endpoint can clear all of its proxies. See the following implementation:

manager.go

This is the implementation of the idea:

package main

import (
    "sync"
)

// Conn is abstraction of a connection to make Manager easy to test.
type Conn interface {
    Close() error
}

// Dialer is abstraction of a dialer to make Manager easy to test.
type Dialer interface {
    Dial(addr string) (Conn, error)
}

type Manager struct {
    // muCurrent protects the "current" member.
    muCurrent sync.RWMutex
    current   int // When current is -1, the manager is shut down.
    endpoints []*endpoint

    // mu protects the whole Switch action.
    mu sync.Mutex
}

func NewManager(dialer Dialer, addresses ...string) *Manager {
    if len(addresses) < 2 {
        panic("a manager should handle at least 2 addresses")
    }

    endpoints := make([]*endpoint, len(addresses))
    for i, addr := range addresses {
        endpoints[i] = &endpoint{
            address: addr,
            dialer:  dialer,
        }
    }
    return &Manager{
        endpoints: endpoints,
    }
}

func (m *Manager) AddProxy(from Conn) {
    // 1. AddProxy will wait when the write lock of m.muCurrent is taken.
    // Once the write lock is released, AddProxy will connect to the new endpoint.
    // Switch only holds the write lock for a short time, and Switch is called
    // not so frequently, so AddProxy won't wait too much.
    // 2. Switch will wait if there is any AddProxy holding the read lock of
    // m.muCurrent. That means Switch waits longer. The advantage is that when
    // e.clear is called in Switch, all AddProxy requests to the old endpoint
    // are done. So it's safe to call e.clear then.
    m.muCurrent.RLock()
    defer m.muCurrent.RUnlock()

    current := m.current

    // Do not accept any new connections after m has been shut down.
    if current == -1 {
        from.Close()
        return
    }

    m.endpoints[current].addProxy(from)
}

func (m *Manager) Switch() {
    // In a real world, Switch is called not so frequently.
    // So it's ok to add a lock here.
    // And it's necessary to make sure the old endpoint is cleared and ready
    // for use in the future.
    m.mu.Lock()
    defer m.mu.Unlock()

    // Take the write lock of m.muCurrent.
    // It waits for all the AddProxy requests holding the read lock to finish.
    m.muCurrent.Lock()
    old := m.current

    // Do nothing when m has been shut down.
    if old == -1 {
        m.muCurrent.Unlock()
        return
    }
    next := old + 1
    if next >= len(m.endpoints) {
        next = 0
    }
    m.current = next
    m.muCurrent.Unlock()

    // When it reaches here, all AddProxy requests to the old endpoint are done.
    // And it's safe to call e.clear now.
    m.endpoints[old].clear()
}

func (m *Manager) Shutdown() {
    m.mu.Lock()
    defer m.mu.Unlock()

    m.muCurrent.Lock()
    current := m.current
    m.current = -1
    m.muCurrent.Unlock()

    m.endpoints[current].clear()
}

type proxy struct {
    from, to Conn
}

type endpoint struct {
    address string
    dialer  Dialer

    mu      sync.Mutex
    proxies []*proxy
}

func (e *endpoint) clear() {
    for _, p := range e.proxies {
        p.from.Close()
        p.to.Close()
    }

    // Assign a new slice to e.proxies, and the GC will collect the old one.
    e.proxies = []*proxy{}
}

func (e *endpoint) addProxy(from Conn) {
    toConn, err := e.dialer.Dial(e.address)
    if err != nil {
        // Close the from connection so that the client will reconnect?
        from.Close()
        return
    }

    e.mu.Lock()
    defer e.mu.Unlock()
    e.proxies = append(e.proxies, &proxy{from: from, to: toConn})
}

main.go

This demo shows how to use the Manager type implemented above:

package main

import (
    "net"
    "time"
)

type realDialer struct{}

func (d realDialer) Dial(addr string) (Conn, error) {
    tcpAddr, err := net.ResolveTCPAddr("tcp4", addr)
    if err != nil {
        return nil, err
    }
    return net.DialTCP("tcp", nil, tcpAddr)
}

func main() {
    manager := NewManager(realDialer{}, "1.1.1.1", "8.8.8.8")

    tcpAddr, _ := net.ResolveTCPAddr("tcp4", "0.0.0.0:5432")
    ln, _ := net.ListenTCP("tcp", tcpAddr)

    go func() {
        for range time.Tick(30 * time.Second) {
            manager.Switch()
        }
    }()
    for {
        clientConn, err := ln.AcceptTCP()
        if err != nil {
            panic(err)
        }
        go manager.AddProxy(clientConn)
    }
}

manager_test.go

Run the tests with: go test ./... -race -count 10

package main

import (
    "errors"
    "math/rand"
    "sync"
    "sync/atomic"
    "testing"
    "time"

    "github.com/google/uuid"
)

func TestManager(t *testing.T) {
    addresses := []string{"1.1.1.1", "8.8.8.8"}
    dialer := newDialer(addresses...)
    manager := NewManager(dialer, addresses...)

    ch := make(chan int, 1)

    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        for range ch {
            manager.Switch()
        }
        wg.Done()
    }()

    count := 1000
    total := count * 10

    wg.Add(total)

    fromConn := &fakeFromConn{}
    for i := 0; i < total; i++ {
        if i%count == count-1 {
            ch <- 0
        }
        go func() {
            manager.AddProxy(fromConn)
            wg.Done()
        }()
    }
    close(ch)

    wg.Wait()

    manager.Shutdown()

    for _, s := range dialer.servers {
        left := len(s.conns)
        if left != 0 {
            t.Errorf("server %s, unexpected connections left: %d", s.addr, left)
        }
    }

    closedCount := fromConn.closedCount.Load()

    if closedCount != int32(total) {
        t.Errorf("want closed count: %d, got: %d", total, closedCount)
    }
}

type fakeFromConn struct {
    closedCount atomic.Int32
}

func (c *fakeFromConn) Close() error {
    c.closedCount.Add(1)

    return nil
}

type fakeToConn struct {
    id     uuid.UUID
    server *fakeServer
}

func (c *fakeToConn) Close() error {
    if c.id == uuid.Nil {
        return nil
    }

    c.server.removeConn(c.id)

    return nil
}

type fakeServer struct {
    addr  string
    mu    sync.Mutex
    conns map[uuid.UUID]bool
}

func (s *fakeServer) addConn() (uuid.UUID, error) {
    s.mu.Lock()
    defer s.mu.Unlock()

    id, err := uuid.NewRandom()
    if err == nil {
        s.conns[id] = true
    }
    return id, err
}

func (s *fakeServer) removeConn(id uuid.UUID) {
    s.mu.Lock()
    defer s.mu.Unlock()

    delete(s.conns, id)
}

type fakeDialer struct {
    servers map[string]*fakeServer
}

func newDialer(addresses ...string) *fakeDialer {
    servers := make(map[string]*fakeServer)
    for _, addr := range addresses {
        servers[addr] = &fakeServer{
            addr:  addr,
            conns: make(map[uuid.UUID]bool),
        }
    }
    return &fakeDialer{
        servers: servers,
    }
}

func (d *fakeDialer) Dial(addr string) (Conn, error) {
    n := rand.Intn(100)
    if n == 0 {
        return nil, errors.New("fake network error")
    }
    // Simulate network latency.
    time.Sleep(time.Duration(n) * time.Millisecond)

    s := d.servers[addr]
    id, err := s.addConn()
    if err != nil {
        return nil, err
    }
    conn := &fakeToConn{
        id:     id,
        server: s,
    }
    return conn, nil
}
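
For completeness: the test imports github.com/google/uuid, and the atomic.Int32 type requires Go 1.19 or newer, so the package needs a go.mod along these lines before go test ./... -race -count 10 will run. The module path and uuid version below are placeholders.

module example.com/proxymanager

go 1.20

require github.com/google/uuid v1.3.0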
