你好,
目前,我们可以使用 TF_FLOAT
数据类型创建或获取一个 TF_Tensor
,然后通过 TF_TensorData
的指针操作原始数据缓冲区。例如:
TF_Tensor* output = TF_AllocateOutput(xxx, TF_FLOAT, xxx);
float* output_raw_buffer = reinterpret_cast<float*>(TF_TensorData(output));
// do some calculation on output_raw_buffer
但是对于数据类型 TF_VARIANT,我们不能像浮点类型那样直接操作它。我们有关于如何使用它的示例吗?
谢谢!
6条答案
按热度按时间ifmq2ha21#
@yanzhang-dev
您可以参考这个
c_api
,希望对您有所帮助。谢谢!ffx8fchx2#
你好@UsharaniPagadala
你提供的测试用例是解析Tensor的元数据,如形状和dtype。
是的,我们也可以创建一个TF_VARIANT Tensor并获取这些元数据。但是如果我们想要获取Tensor的元素,例如:
TF_FLOAT
,我们可以这样解析原始缓冲区。TF_Variant
,我们不能像上面那样通过TF_TensorData
来解析原始缓冲区。我们应该怎么做呢?谢谢!
2jcobegt3#
@penpornk
5cg8jx4n4#
你好,@yanzhang-dev,
抱歉延迟了!
TF_VARIANT / DT_VARIANT 是一种用于封装C++数据结构的数据类型。通常情况下,你可以将DT_VARIANTTensor视为字节块的 Package 器,其值将根据其所在内核进行解释。
例如:
由于 Variant::get() 是一个模板函数,我们不能直接将其暴露给C API。但我认为我们可以为C API创建一个返回类型擦除的指针到二进制块,例如,在 TensorInterface::Data 下面添加
TensorInterface::VariantData
。然后你的内核将接收这个void指针,并将二进制块转换为C对象。这需要插件具有与核心TensorFlow完全相同的头文件。C类也没有版本保证。所以我不确定尝试支持
DT_VARIANT
是否值得。5cnsuln75#
@penpornk,有没有一个Python级别的操作列表,最终使用了变体数据类型?例如,tensorflow.keras.layers.LSTM使用了TensorList。不确定TensorList是否来自父类tf.keras.layers.RNN。如果是前者,那么所有的RNN对插件都不友好。我能想到的一个解决这个问题的方法是用不使用变体类型的操作实现替换它吗?对于LSTM或RNN有这样的替代方案吗?
x3naxklr6#
请告诉我在哪里可以找到源代码中
tensorflow.keras.layers.LSTM
-->TensorList
-->DT_VARIANT
的使用情况?