tensorflow 模型在变形状Tensor中的应用

rkttyhzu  于 2022-11-25  发布在  其他
关注(0)|答案(1)|浏览(124)

我有一个tf.keras模型,它输入一个形状为(batch_size, )的Tensor,输出另一个形状相同的Tensor,索引为i的结果不依赖于索引为j != i的任何输入。
我想把这个模型应用到任何形状的Tensor(dim1, dim2, ..., dimn)上。理论上这应该是可能的,但实际上tensorflow 拒绝处理任何输入形状超过1维的东西。什么是绕过这个问题的最好的解决方案?我已经看过tf.map_fn,但当递归使用时可能会变得复杂。我忽略了什么更简单的方法?

anauzrmj

anauzrmj1#

最后我这样解决了它:

def apply_model(X: tf.Tensor, my_model: tf.keras.Model) -> tf.Tensor:
        """
        Apply a tf.keras.Model to a tensor of unknown dimensions.

        Args:
            X (tf.Tensor): The tensor containing the input.
            my_model (tf.keras.Model): The model you want to apply.

        Returns:
            tf.Tensor: A tensor of the same shape as X, where all values are
                a prediction by the model.
        """
        if len(X.shape) > 1:
            result = tf.stack(
                [apply_model(x) for x in tf.unstack(X, axis=-1)],
                axis=-1,
            )
        else:
            result = my_model(X)

        return result

当然,您可以将其推广到模型采用多于一维的输入的情况。

相关问题