我想在下面一批维度为[batch,4,4] = [2,4,4]的采样数组中找到的位置,其中的是“1”。
import jax
import jax.numpy as jnp
a = jnp.array([[[0., 0., 0., 1.],
[0., 0., 0., 0.],
[0., 1., 0., 1.],
[0., 0., 1., 1.]],
[[1., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 1., 0., 1.]]])
我试着遍历批处理的维度(使用vmap),并使用jax函数查找坐标
b = jax.vmap(jnp.where)(a)
print('b', b)
但我得到了一个错误,我不知道如何修复:
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This Tracer was created on line /home/imi/Desktop/Backflow/backflow/src/debug.py:17 (<module>)
我希望得到以下输出:
b = [[[0,3], [2,1],[2,3],[3,2],[3,3]],
[[0,0],[0,2],[1,0],[3,1],[3,3]]
[x,y]坐标的第一行对应于第一批中存在“1”的位置,并且对应于第二批中的第二行。
1条答案
按热度按时间myzjeezk1#
像
vmap
这样的JAX转换要求数组的大小是静态的,因此没有办法精确地执行您所考虑的计算(因为1
条目的数量以及输出数组的大小是依赖于数据的)。但是如果你事先知道每批有五个条目,你可以做如下的事情:
如果您不知道 * 先验 * 有多少
1
条目,那么您有几个选择:一种是避免JAX转换,并在每个批处理上调用未转换的jnp.where
:一个二个一个一个
注意,对于这种情况,通常不可能将结果存储在单个数组中,因为每个批处理中可能有不同数量的
1
条目,并且JAX不支持不规则数组。另一个选项是将
size
设置为某个最大值,并输出填充的结果:有了填充的结果,您就可以编写代码的其余部分来预测这些填充的值。