我有一个tf.keras
模型,它输入一个形状为(batch_size, )
的Tensor,输出另一个形状相同的Tensor,索引为i
的结果不依赖于索引为j != i
的任何输入。
我想把这个模型应用到任何形状的Tensor(dim1, dim2, ..., dimn)
上。理论上这应该是可能的,但实际上tensorflow 拒绝处理任何输入形状超过1维的东西。什么是绕过这个问题的最好的解决方案?我已经看过tf.map_fn
,但当递归使用时可能会变得复杂。我忽略了什么更简单的方法?
1条答案
按热度按时间anauzrmj1#
最后我这样解决了它:
当然,您可以将其推广到模型采用多于一维的输入的情况。