linux 从另一个线程访问thread-local

zu0ti5jz  于 2023-10-16  发布在  Linux
关注(0)|答案(5)|浏览(108)

如何从另一个线程读取/写入线程局部变量?也就是说,在线程A中,我想访问线程B的线程本地存储区域中的变量。我知道另一个线程的ID。
变量在GCC中被声明为__thread。目标平台是Linux,但独立性可能会更好(不过特定于GCC也可以)。
由于缺少线程开始钩子,我无法在每个线程开始时简单地跟踪此值。所有线程都需要以这种方式跟踪(不仅仅是专门启动的线程)。
更高级别的 Package 器,如boost thread_local_storage或使用pthread键都不是一个选择。我需要使用真正的__thread局部变量的性能。

  • 第一个答案是错误的 *:不能用全局变量来做我想做的事情。每个线程都必须有自己的变量副本。此外,出于性能原因,这些变量必须是__thread变量(同样有效的解决方案也可以,但我不知道)。我也不控制线程入口点,因此这些线程不可能 * 注册 * 任何类型的结构。
  • Thread Local不是私有的 *:关于线程局部变量的另一个误解。这些绝对不是线程的某种 * 私有 * 变量。它们是全局可寻址的内存,但有一个限制,即它们的生存期与线程有关。来自任何线程的任何函数,如果给它们一个指向这些变量的指针,就可以修改它们。上面的问题本质上是关于如何获得指针地址。
mv1qrgav

mv1qrgav1#

如果你想要线程局部变量而不是线程局部变量,为什么不使用全局变量呢?

重要说明!

我并不是建议您使用单个全局变量来替换线程局部变量。我建议使用一个全局 * 数组 * 或其他合适的值集合来替换一个线程局部变量。
当然,您必须提供同步,但由于您希望将在线程A中修改的值公开给线程B,因此无法绕过这个问题。

更新:

__thread上的GCC文档说:
当address-of运算符应用于线程局部变量时,它在运行时计算并返回该变量的当前线程示例的地址。这样获得的地址可以由任何线程使用。当线程终止时,该线程中任何指向线程局部变量的指针都将无效。
因此,如果你坚持这样做,我想有可能在线程生成之后,从它所属的线程中获取线程局部变量的地址。然后,您可以将指向该内存位置的指针存储到一个Map中(线程id => pointer),并让其他线程以这种方式访问该变量。这假定您拥有派生线程的代码。
如果你真的很喜欢冒险,你可以尝试挖掘关于___tls_get_addr的信息(从前面提到的GCC文档链接的this PDF开始)。但是这种方法是高度特定于编译器和平台的,并且缺乏文档,因此它应该引起任何人的警觉。

gr8qqesn

gr8qqesn2#

我在寻找同样的东西。正如我看到没有人回答你的问题后,搜索了网络上的所有方式,我到达后续信息:假设在Linux(Ubuntu)上编译GCC并使用-m64,则段寄存器gs保持值0。段的隐藏部分(保存线性地址)指向线程特定的局部区域。该区域在该地址处包含该地址的地址(64位)。在较低的地址存储所有线程局部变量。地址是native_handle()。因此,为了访问线程的本地数据,您应该通过该指针进行访问。
换句话说:(char*)&variable-(char*)myThread.native_handle()+(char*)theOtherThread.native_handle()
上面的代码演示了假设g++,Linux,pthreads是:

#include <iostream>
#include <thread>
#include <sstream>

thread_local int B=0x11111111,A=0x22222222;

bool shouldContinue=false;

void code(){
    while(!shouldContinue);
    std::stringstream ss;
    ss<<" A:"<<A<<" B:"<<B<<std::endl;
    std::cout<<ss.str();
}

//#define ot(th,variable) 
//(*( (char*)&variable-(char*)(pthread_self())+(char*)(th.native_handle()) ))

int& ot(std::thread& th,int& v){
    auto p=pthread_self();
    intptr_t d=(intptr_t)&v-(intptr_t)p;
    return *(int*)((char*)th.native_handle()+d);
}

int main(int argc, char **argv)
{       

        std::thread th1(code),th2(code),th3(code),th4(code);

        ot(th1,A)=100;ot(th1,B)=110;
        ot(th2,A)=200;ot(th2,B)=210;
        ot(th3,A)=300;ot(th3,B)=310;
        ot(th4,A)=400;ot(th4,B)=410;

        shouldContinue=true;

        th1.join();
        th2.join();
        th3.join();
        th4.join();

    return 0;
}
3z6pesqy

3z6pesqy3#

这是一个老问题,但既然没有给出答案,为什么不使用一个拥有自己静态注册的类呢?

#include <mutex>
#include <thread>
#include <unordered_map>

struct foo;

static std::unordered_map<std::thread::id, foo*> foos;
static std::mutex foos_mutex;

struct foo
{
    foo()
    {
        std::lock_guard<std::mutex> lk(foos_mutex);
        foos[std::this_thread::get_id()] = this;
    }
};

static thread_local foo tls_foo;

当然,你需要在线程之间进行某种同步,以确保线程已经注册了指针,但是你可以从任何你知道线程id的线程的map中获取它。

dxxyhpgq

dxxyhpgq4#

很不幸,我从来没有找到一种方法来做到这一点。
如果没有某种类型的线程初始化钩子,似乎就没有办法到达那个指针(除了依赖于平台的ASM黑客)。

mbyulnm0

mbyulnm05#

这几乎是你所需要的,如果不修改你的要求。
在Linux上,它使用pthread_key_create,windows使用TlsAlloc。它们都是通过 key 检索本地线程的一种方式。然而,如果你注册了键,你就可以在其他线程上访问数据.
EnumerableThreadLocal的思想是在线程中执行本地操作,然后在主线程中减少结果。
tbb有一个类似的函数,名为可执行线程特定的,它的动机可以在https://oneapi-src.github.io/oneTBB/main/tbb_userguide/design_patterns/Divide_and_Conquer.html中找到
下面是一个尝试,模仿tbb代码,而不依赖于tbb。下面的代码的缺点是你在Windows上被限制为1088个键。

template <typename T>
    class EnumerableThreadLocal
    {

#if _WIN32 || _WIN64
        using tls_key_t = DWORD;
        void create_key() { my_key = TlsAlloc(); }
        void destroy_key() { TlsFree(my_key); }
        void set_tls(void *value) { TlsSetValue(my_key, (LPVOID)value); }
        void *get_tls() { return (void *)TlsGetValue(my_key); }
#else
        using tls_key_t = pthread_key_t;
        void create_key() { pthread_key_create(&my_key, nullptr); }
        void destroy_key() { pthread_key_delete(my_key); }
        void set_tls(void *value) const { pthread_setspecific(my_key, value); }
        void *get_tls() const { return pthread_getspecific(my_key); }
#endif
        std::vector<std::pair<std::thread::id, std::unique_ptr<T>>> m_thread_locals;
        std::mutex m_mtx;
        tls_key_t my_key;

        using Factory = std::function<std::unique_ptr<T>()>;
        Factory m_factory;

        static auto DefaultFactory()
        {
            return std::make_unique<T alignas(hardware_constructive_interference_size)>();
        }

    public:

        EnumerableThreadLocal(Factory factory = &DefaultFactory ) : m_factory(factory)
        {
            create_key();
        }

        ~EnumerableThreadLocal()
        {
            destroy_key();
        }

        EnumerableThreadLocal(const EnumerableThreadLocal &other)
        {
            create_key();
            // deep copy the m_thread_locals
            m_thread_locals.reserve(other.m_thread_locals.size());
            for (const auto &pair : other.m_thread_locals)
            {
                m_thread_locals.emplace_back(pair.first, std::make_unique<T>(*pair.second));
            }
        }

        EnumerableThreadLocal &operator=(const EnumerableThreadLocal &other)
        {
            if (this != &other)
            {
                destroy_key();
                create_key();
                m_thread_locals.clear();
                m_thread_locals.reserve(other.m_thread_locals.size());
                for (const auto &pair : other.m_thread_locals)
                {
                    m_thread_locals.emplace_back(pair.first, std::make_unique<T>(*pair.second));
                }
            }
            return *this;
        }

        EnumerableThreadLocal(EnumerableThreadLocal &&other) noexcept
        {
            // deep move
            my_key = other.my_key;
            // deep move the m_thread_locals
            m_thread_locals = std::move(other.m_thread_locals);
            other.my_key = 0;

        }

        EnumerableThreadLocal &operator=(EnumerableThreadLocal &&other) noexcept
        {
            if (this != &other)
            {
                destroy_key();
                my_key = other.my_key;
                m_thread_locals = std::move(other.m_thread_locals);
                other.my_key = 0;
            }
            return *this;
        }

        T *Get ()
        {
            void *v = get_tls();
            if (v)
            {
                return reinterpret_cast<T *>(v);
            }
            else
            {
                const std::scoped_lock l(m_mtx);
                for (const auto &[thread_id, uptr] : m_thread_locals)
                {
                    // This search is necessary for the case if we run out of TLS indicies in customer's process, and we do at least slow lookup
                    if (thread_id == std::this_thread::get_id())
                    {
                        set_tls(reinterpret_cast<void *>(uptr.get()));
                        return uptr.get();
                    }
                }

                m_thread_locals.emplace_back(std::this_thread::get_id(), m_factory());
                T *ptr = m_thread_locals.back().second.get();
                set_tls(reinterpret_cast<void *>(ptr));
                return ptr;
            }
        }

        T const * Get() const
        {
            return const_cast<EnumerableThreadLocal *>(this)->Get();
        }

        T & operator *()
        {
            return *Get();
        }

        T const & operator *() const
        {
            return *Get();
        }

        T * operator ->()
        {
            return Get();
        }

        T const * operator ->() const
        {
            return Get();
        }

        template <typename F>
        void Enumerate(F fn)
        {
            const std::scoped_lock lock(m_mtx);
            for (auto &[thread_id, ptr] : m_thread_locals)
                fn(*ptr);
        }
    };

以及一组测试案例来展示它是如何工作的

#include <thread>
#include <string>
#include "gtest/gtest.h"
#include "EnumerableThreadLocal.hpp"

TEST(EnumerableThreadLocal, BasicTest)
{
    const int N = 10;
    v31::EnumerableThreadLocal<std::string> tls;

    // Create N threads and assign a string including the thread ID to the tls
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { *tls = "Thread " + std::to_string(i); });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    std::vector<std::string> expected;
    tls.Enumerate([&](std::string &s)
                  { expected.push_back(s); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], "Thread " + std::to_string(i));
    }

}

// Create a non copyable type, non moveable type
struct NonCopyable
{
    int i=0;
    NonCopyable() = default;
    NonCopyable(const NonCopyable &) = delete;
    NonCopyable(NonCopyable &&) = delete;
    NonCopyable &operator=(const NonCopyable &) = delete;
    NonCopyable &operator=(NonCopyable &&) = delete;
};

// A test to see if we can insert non moveable/ non copyable types to the tls
TEST(EnumerableThreadLocal, NonCopyableTest)
{
    const int N = 10;
    v31::EnumerableThreadLocal<NonCopyable> tls;

    // Create N threads and assign a string including the thread ID to the tls
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { tls->i=i; });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    std::vector<int> expected;
    tls.Enumerate([&](NonCopyable &s)
                  { expected.push_back(s.i); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], i);
    }
}

const int N = 10;
v31::EnumerableThreadLocal<std::string> CreateFixture()
{
    v31::EnumerableThreadLocal<std::string> tls;

    // Create N threads and assign a string including the thread ID to the tls
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { *tls = "Thread " + std::to_string(i); });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    return tls;
}

void CheckFixtureCopy(v31::EnumerableThreadLocal<std::string> & tls)
{
    std::vector<std::string> expected;

    tls.Enumerate([&](std::string &s)
                    { expected.push_back(s); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], "Thread " + std::to_string(i));
    }
}

void CheckFixtureEmpty(v31::EnumerableThreadLocal<std::string> & tls)
{
    std::vector<std::string> expected;

    tls.Enumerate([&](std::string &s)
                    { expected.push_back(s); });

    ASSERT_EQ(expected.size(), 0);
}

/// Test for copy construct of EnumerableThreadLocal
TEST(EnumerableThreadLocal, Copy)
{
    auto tls = CreateFixture();
    // Copy the tls
    auto tls_copy = tls;

    CheckFixtureCopy(tls_copy);
    CheckFixtureCopy(tls);
}

/// Test for move construct of EnumerableThreadLocal
TEST(EnumerableThreadLocal, Move)
{
    auto tls = CreateFixture();
    // Copy the tls
    auto tls_copy = std::move(tls);

    CheckFixtureCopy(tls_copy);
    CheckFixtureEmpty(tls);
}

/// Test for copy assign of EnumerableThreadLocal
TEST(EnumerableThreadLocal, CopyAssign)
{
    auto tls = CreateFixture();
    // Copy the tls
    v31::EnumerableThreadLocal<std::string> tls_copy;
    CheckFixtureEmpty(tls_copy);
    tls_copy = tls;

    CheckFixtureCopy(tls_copy);
    CheckFixtureCopy(tls);
}   

/// Test for move assign of EnumerableThreadLocal
TEST(EnumerableThreadLocal, MoveAssign)
{
    auto tls = CreateFixture();
    // Copy the tls
    v31::EnumerableThreadLocal<std::string> tls_copy;
    CheckFixtureEmpty(tls_copy);
    tls_copy = std::move(tls);

    CheckFixtureCopy(tls_copy);
    CheckFixtureEmpty(tls);
}

//class with no default constructor
struct NoDefaultConstructor
{
    int i;
    NoDefaultConstructor(int i) : i(i) {}
};

// Test for using objects with no default constructor
TEST(EnumerableThreadLocal, NoDefaultConstructor)
{
    const int N = 10;
    v31::EnumerableThreadLocal<NoDefaultConstructor> tls([]{return std::make_unique<NoDefaultConstructor>(0);});

    // Create N threads and assign a string including the thread ID to the tls
    std::vector<std::thread> threads;
    for (int i = 0; i < N; ++i)
    {
        threads.emplace_back([&tls, i]()
                             { tls->i = i; });
    }

    // Wait for all threads to finish
    for (auto &thread : threads)
        thread.join();

    // enumerate and sort and verify
    std::vector<int> expected;  
    tls.Enumerate([&](NoDefaultConstructor &s)
                    { expected.push_back(s.i); });

    // Sort the expected vector
    std::sort(expected.begin(), expected.end());

    // check the expected vector
    for (int i = 0; i < N; ++i)
    {
        ASSERT_EQ(expected[i], i);
    }

}

相关问题