tensorflow 如何使用TF_VARIANT与C API?

ws51t4hk  于 4个月前  发布在  其他
关注(0)|答案(6)|浏览(95)

你好,

目前,我们可以使用 TF_FLOAT 数据类型创建或获取一个 TF_Tensor,然后通过 TF_TensorData 的指针操作原始数据缓冲区。例如:

TF_Tensor* output = TF_AllocateOutput(xxx, TF_FLOAT, xxx);
float* output_raw_buffer = reinterpret_cast<float*>(TF_TensorData(output));
// do some calculation on output_raw_buffer

但是对于数据类型 TF_VARIANT,我们不能像浮点类型那样直接操作它。我们有关于如何使用它的示例吗?
谢谢!

ifmq2ha2

ifmq2ha21#

@yanzhang-dev

您可以参考这个c_api,希望对您有所帮助。谢谢!

ffx8fchx

ffx8fchx2#

你好@UsharaniPagadala
你提供的测试用例是解析Tensor的元数据,如形状和dtype。
是的,我们也可以创建一个TF_VARIANT Tensor并获取这些元数据。但是如果我们想要获取Tensor的元素,例如:

  1. 如果Tensor的元素是TF_FLOAT,我们可以这样解析原始缓冲区。
  2. 但是如果Tensor的元素是TF_Variant,我们不能像上面那样通过TF_TensorData来解析原始缓冲区。我们应该怎么做呢?
    谢谢!
5cg8jx4n

5cg8jx4n4#

你好,@yanzhang-dev,
抱歉延迟了!
TF_VARIANT / DT_VARIANT 是一种用于封装C++数据结构的数据类型。通常情况下,你可以将DT_VARIANTTensor视为字节块的 Package 器,其值将根据其所在内核进行解释。
例如:

由于 Variant::get() 是一个模板函数,我们不能直接将其暴露给C API。但我认为我们可以为C API创建一个返回类型擦除的指针到二进制块,例如,在 TensorInterface::Data 下面添加 TensorInterface::VariantData

void* TensorInterface::VariantData() const {
  return tensor_.scalar<Variant>()().get<void>();
}

然后你的内核将接收这个void指针,并将二进制块转换为C对象。这需要插件具有与核心TensorFlow完全相同的头文件。C类也没有版本保证。所以我不确定尝试支持 DT_VARIANT 是否值得。

5cnsuln7

5cnsuln75#

@penpornk,有没有一个Python级别的操作列表,最终使用了变体数据类型?例如,tensorflow.keras.layers.LSTM使用了TensorList。不确定TensorList是否来自父类tf.keras.layers.RNN。如果是前者,那么所有的RNN对插件都不友好。我能想到的一个解决这个问题的方法是用不使用变体类型的操作实现替换它吗?对于LSTM或RNN有这样的替代方案吗?

x3naxklr

x3naxklr6#

请告诉我在哪里可以找到源代码中 tensorflow.keras.layers.LSTM --> TensorList --> DT_VARIANT 的使用情况?

相关问题