如何添加ARG_MAX操作在tensorflow lite安卓与GPU?

7gs2gvoe  于 2022-10-29  发布在  其他
关注(0)|答案(2)|浏览(178)

我正在使用tensorflow 2.5.0。我取一个四维Tensor,并在lambda函数中执行argmax:

def activate_postprocess(tensor):
  res = []
  for i in range(k):
    res.append(tf.argmax(tensor[i]))
  return res

post_processe_layer = tf.keras.layers.Lambda(activate_postprocess)
post_processed_output = post_processed_layer(model.outputs)

然后,我将TFLITE_BUILTINSSELECT_TF_OPS相加,进行转换:

converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                             tf.lite.OpsSet.SELECT_TF_OPS
                                             ]
    tflite_model = converter.convert()

然后我在Android应用程序中运行,使用GPU代理

val options = Interpreter.Options()
val compatList = CompatibilityList()
val delegateOptions = compatList.bestOptionsForThisDevice
                gpuDelegate = GpuDelegate(delegateOptions)
                options.addDelegate(gpuDelegate)

运行模型后,我收到以下错误:

Failed for 'img.jpg' with error: Internal error: Failed to apply delegate: Following operations are not supported by GPU delegate:
    ARG_MAX: Operation is not supported.

在创建解释器之前,我也尝试了options.setUseNNAPI(true),但它没有帮助。
我该怎么解决呢?
我看到这个:https://www.tensorflow.org/lite/guide/ops_custom
但我不明白如何注册运营商(https://www.tensorflow.org/lite/guide/ops_custom#create_and_register_the_operator)
有argmax的例子吗?
谢谢

sd2nnvve

sd2nnvve1#

是的,我不记得我们的团队实现了ARG_MAX操作。我们拥有的最接近的操作可能是带索引的MAX_POOL。着色器代码应该在存储库中。
向委托添加自定义操作需要大量的管道。

  • 首先,你需要让tflite convert理解你的自定义操作(我从来没有在这个领域工作过,不知道它有多容易)。
  • 然后,您需要创建并注册操作,以便操作解析器理解自定义操作。如果您知道自己将完全在GPU端工作,则可以跳过操作的实现,但通常情况下,拥有参考CPU实现是有益的。
  • 之后,您需要更改GPU代理的模型构建器以识别新的操作。这涉及到解析操作并在GraphFloat32中表示它,GraphFloat32是GPU代理使用的内部数据结构。可能涉及到属性的定义。
  • 最后但并非最不重要的是,你需要实现着色器代码。如果你不精通计算着色器,这可能是一个主要的头痛。
pkwftd7m

pkwftd7m2#

@yonatanbitton NNAPI支持ArgMax。但是,tf.argmax的默认输出数据类型为int64NNAPI's argmax不支持此类型。如果int32足以满足您的使用要求,请使用tf.argmax(tensor[i], output_type=tf.dtypes.int32)

相关问题