C++:原子标记指针,高位为16位计数器,似乎无法递增计数器

qmelpv7a  于 2022-12-20  发布在  其他
关注(0)|答案(1)|浏览(164)

为了便于学习,我正在尝试实现一个原子标记/封装指针。
我想将高16位用于uint16_t计数器,将低3位用于3位标记。
到目前为止,除了计数器的递增功能之外,我已经设法让所有的东西都工作了。我对位操作不是很熟悉,所以我认为错误可能是在我使用它们的某个地方。
我的实现如下:

当前输出为:

AtomicTaggedPointer(ptr=0x5589e2dc92b0, val=42, tag=4, count=0) 
0000000000000000 0101010110001001 1110001011011100 1001001010110 100

AtomicTaggedPointer(ptr=0x5589e2dca2e0, val=43, tag=5, count=0) 
0000000000000000 0101010110001001 1110001011011100 1010001011100 101

我想实现的是,当第二次打印时,count=1,我们看到它存储在高16位。

#include <atomic>
#include <cassert>
#include <cstdint>
#include <cstdio>

// A word-aligned, atomic tagged pointer.
// Uses both the upper 16 bits for storage, and the lower 3 bits for tagging.
//
//   64                48                32                16
// 0xXXXXXXXXXXXXXXXX  0000000000000000  0000000000000000  0000000000000XXX
//   ^                 ^                                                ^
//   |                 |                                                +-- Tag (3 bits)
//   |                 +-- Pointer (48 bits)
//   +-- Counter (16 bits)
//
//
// The tag is 3 bits, allowing for up to 8 different tags.
//
// The version is incremented every time the pointer is CAS'd. This is useful
// for detecting concurrent modifications to a pointer.
template <typename T>
struct AtomicTaggedPointer
{
    static_assert(sizeof(T*) == 8, "T* must be 8 bytes");
    static_assert(sizeof(std::atomic<uintptr_t>) == 8, "uintptr_t must be 8 bytes");

  private:
    static constexpr uintptr_t kTagMask      = 0x7;                // 3 bits
    static constexpr uintptr_t kCounterMask  = 0xFFFF000000000000; // 16 bits
    static constexpr uintptr_t kPointerMask  = ~kTagMask;          // All bits except tag bits
    static constexpr uintptr_t kCounterShift = 48;                 // Shift counter bits to the left

    std::atomic<uintptr_t> value;

  public:
    AtomicTaggedPointer(T* ptr, uint8_t tag = 0)
    {
        value.store(reinterpret_cast<uintptr_t>(ptr) | tag, std::memory_order_relaxed);
    }

    T* get() const
    {
        return reinterpret_cast<T*>(value.load(std::memory_order_relaxed) & kPointerMask);
    }

    uint8_t tag() const
    {
        return value.load(std::memory_order_relaxed) & kTagMask;
    }

    uint16_t counter() const
    {
        return value.load(std::memory_order_relaxed) >> kCounterShift;
    }

    // Compare and swap the pointer with the desired value, and optionally set the tag.
    // Returns true if the swap was successful.
    // Also increments the version counter by 1.
    bool cas(T* desired, uint8_t tag = 0)
    {
        uintptr_t expected = value.load(std::memory_order_relaxed);
        uintptr_t desired_value =
            reinterpret_cast<uintptr_t>(desired) | (tag & kTagMask) | ((expected + 1) & kCounterMask) << 48;
        return value.compare_exchange_strong(expected, desired_value, std::memory_order_relaxed);
    }

    void print() const
    {
        printf("AtomicTaggedPointer(ptr=%p, val=%d, tag=%hhu, count=%hu) \n", get(), *get(), tag(), counter());
        // Print each bit of the pointer's 64-bit value
        // In the format:
        // 0xXXXXXXXXXXXXXXXX  0000000000000000  0000000000000000  0000000000000XXX
        uintptr_t v = value.load(std::memory_order_relaxed);
        for (int i = 63; i >= 0; i--)
        {
            if (i == 2 || i == 15 || i == 31 || i == 47)
            {
                printf(" ");
            }
            printf("%lu", (v >> i) & 1);
        }
        printf("\n");
    }
};

int
main()
{
    AtomicTaggedPointer<int> p = AtomicTaggedPointer<int>(new int(42), 4);
    p.print();
    assert(p.get() != nullptr);
    assert(*p.get() == 42);
    assert(p.tag() == 4);
    assert(p.counter() == 0);

    int* expected = p.get();
    p.cas(new int(43), 5);
    p.print();
}
wbgh16ku

wbgh16ku1#

通过增加1来增加expected。这将是最低位的增量。但这是您放置标记的位置。计数器位于最高位。因此您需要首先将1移位到计数器的最低位,即移位kCounterShift(为此,您应该首先将1强制转换为适当的类型uintptr_t,以确保移位在范围内)。
此外,您的kPointerMask是错误的,因为它没有屏蔽计数器位。
另外,为了确保指针正确对齐,你应该在alignof(T)上添加一个足够大的静态Assert,这样你就可以放心了,因为在标准C++中不可能为一个类型传递有效的未对齐指针值,即使在允许未对齐访问的平台上,也不允许像这样传递和解引用未正确对齐的指针。(参见答案下面的注解)
在您的特定示例中,int将无法满足该要求。它将仅对齐4字节,而不是您需要的8字节。您使用new创建对象可能会避免您获得未对齐8字节的指针,但即使对于new int,这通常也无法保证。
当然,这里有很多是实现定义的行为。例如,所使用的指针中的位的布局以及它们如何Map到uintptr_t是特定于x86-64和典型ABI的。我可以想象编译器使用指针中未使用的位来达到某种目的或类似目的。
可能还有一个更广泛的问题,即整个方法是否与指针出处兼容,对此我并不确定。

相关问题