c++ 在CUDA的Thrust库中,无法使用箭头运算符区分device_ptr

yzckvree  于 2023-02-14  发布在  其他
关注(0)|答案(2)|浏览(215)

在尝试将我的C++光线跟踪代码转换为CUDA时,我无法遵从device_ptr的device_reference,这是在遍历device_ptr的device_vector时无意中创建的。

class hittable_list : public hittable {
public:
    __device__ hittable_list() {}
    __device__ hittable_list(device_ptr<hittable> object) { add(object); }

    __device__ void clear() { objects.clear(); }
    __device__ void add(device_ptr<hittable> object) { objects.push_back(object); }

    __device__ virtual bool hit(const ray& r, float t_min, float t_max, hit_record& rec) const override;

public:
    device_vector<device_ptr<hittable>> objects;
};

我希望在循环此向量时,会收到其中的device_ptrs

__device__ bool hittable_list::hit(const ray& r, float t_min, float t_max, hit_record& rec) const {
    hit_record temp_rec; // temp_rec is used to store the hit_record of all objects
    bool hit_anything = false;  // hit_anything is used to check if any object is hit
    float total_prob = 1.0;  // total_prob is used to store the total probability of the transmission

    for (const auto object: objects) { // loop through all objects
        if (object->hit(r, t_min, t_max, temp_rec)) {
            hit_anything = true;
            total_prob *= temp_rec.trans_prob; // update the total_prob
            temp_rec.trans_prob = total_prob;
            rec = temp_rec;
        }
    }
    return hit_anything;

但是,当我在对象上使用箭头操作符时,我得到以下错误:

error: operator -> or ->* applied to "const thrust::device_reference<const thrust::device_ptr<hittable>>" instead of to a pointer type

那么device_ptr是如何变成device_reference的呢?我又是如何获得device_ptr的呢?文档中确实提到“device_reference不打算直接使用;类似地,获取device_reference的地址将生成device_ptr”((device_reference的文档)但是,获取引用的地址对我来说没有任何意义,即使使用&object->hit()尝试也会导致相同的错误。
我尝试使用箭头运算符(*object).hit()的同义词,但错误仍然显示它仍然是设备引用

5t7ly7z5

5t7ly7z51#

虽然可以在设备代码中使用 Package 器,如thrust::device_ptrthrust::device_reference,但它们并不设计为将指针作为值类型保存。解引用device_ptr给出device_reference,其具有实现的典型值类型的许多运算符,但不具有指针类型的运算符,如进一步解引用,->或通常使用成员方法(在下面的示例中为.get())。因此,必须强制(但安全地)将引用强制转换为它的底层类型,即static_cast。我不知道为什么device_reference::operator value_type (void)没有出现在documentation中,但是从device_reference<T>T的转换/转换肯定是一个预期的特性。这创建了一个副本,而不是一个引用T&,这对T = device_ptr<...>来说很好,因为指针是轻量级的。
因此,虽然可以使用device_vector<device_ptr<T>>(见下面的示例),但这可能不是一个好主意。这些 Package 器的主要目的是避免在Thrust算法中使用显式执行策略。当用于设备代码时,无法向主机调度,因此设备 Package 器最多只能携带不必要的信息。
使用嵌套device_ptr的示例(即使它确实有效,也要避免这种情况):

#include <thrust/device_vector.h>
#include <thrust/for_each.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/execution_policy.h>

class Foo {
    int i{};

    public:
    Foo() = default;

    __host__ __device__
    Foo(int j) noexcept : i{j} {}

    __host__ __device__
    int get() const noexcept { return i; }
};

int main() {
    thrust::device_vector<Foo> foo(1, 42);
    thrust::device_vector<thrust::device_ptr<Foo>> bar(1, nullptr);
    
    auto foo_ptr = foo.data();     // gives device_ptr<Foo>
    auto foo_ptr_ptr = bar.data(); // gives device_ptr<device_ptr<Foo>>

    thrust::for_each(
        thrust::device, // this is optional here as device is the default when the given iterators allow for it, but I like to be explicit about it
        thrust::make_counting_iterator(0),
        thrust::make_counting_iterator(1),
        [foo_ptr, foo_ptr_ptr] __host__ __device__ (int idx) {
            foo_ptr_ptr[idx] = &foo_ptr[idx]; // device_reference<T>::operator=(T&) (with T = device_ptr<Foo>) is implemented
            printf(
                "%d\n", 
                static_cast<thrust::device_ptr<Foo>>(
                    foo_ptr_ptr[idx] // gives device_reference<device_ptr<Foo>>
                )->get() // device_ptr<T>::operator->() is implemented
            );
        });
    
    return 0;
}

如果要使用*解引用,则需要两次强制转换,即

static_cast<Foo>(*static_cast<device_ptr<Foo>>(foo_ptr_ptr[idx])).get()

因为X1 M15 N1 X不具有X1 M16 N1 X成员函数。
不使用 Package (这样做是为了提高可读性):

// ...same includes and Foo class as in above snippet...

int main() {
    thrust::device_vector<Foo> foo(1, 42);
    thrust::device_vector<Foo*> bar(1, nullptr);
    
    auto foo_ptr = thrust::raw_pointer_cast(foo.data());     // gives Foo*
    auto foo_ptr_ptr = thrust::raw_pointer_cast(bar.data()); // gives Foo**

    thrust::for_each(
        thrust::device, // this is optional here as device is the default when the given iterators allow for it, but I like to be explicit about it
        thrust::make_counting_iterator(0),
        thrust::make_counting_iterator(1),
        [foo_ptr, foo_ptr_ptr] __host__ __device__ (int idx) {
            foo_ptr_ptr[idx] = &foo_ptr[idx];
            printf("%d\n", foo_ptr_ptr[idx]->get());
        });
    
    return 0;
}
kx5bkwkv

kx5bkwkv2#

根据迭代器和静态调度,您应该为此使用 raw_pointer_cast

hittable *pHittable = thrust::raw_pointer_cast(object);
if (pHittable->hit(r, t_min, t_max, temp_rec)) {

相关问题