numpy 给定一批样本,查找x,y位置- python,jax

tyky79it  于 2023-01-17  发布在  Python
关注(0)|答案(1)|浏览(116)

我想在下面一批维度为[batch,4,4] = [2,4,4]的采样数组中找到的位置,其中的是“1”。

import jax
import jax.numpy as jnp

a = jnp.array([[[0., 0., 0., 1.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.],
                [0., 0., 1., 1.]],
             
               [[1., 0., 1., 0.],
                [1., 0., 0., 0.],
                [0., 0., 0., 0.],
                [0., 1., 0., 1.]]])

我试着遍历批处理的维度(使用vmap),并使用jax函数查找坐标

b = jax.vmap(jnp.where)(a)
print('b', b)

但我得到了一个错误,我不知道如何修复:

The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This Tracer was created on line /home/imi/Desktop/Backflow/backflow/src/debug.py:17 (<module>)

我希望得到以下输出:

b = [[[0,3], [2,1],[2,3],[3,2],[3,3]],

     [[0,0],[0,2],[1,0],[3,1],[3,3]]

[x,y]坐标的第一行对应于第一批中存在“1”的位置,并且对应于第二批中的第二行。

myzjeezk

myzjeezk1#

vmap这样的JAX转换要求数组的大小是静态的,因此没有办法精确地执行您所考虑的计算(因为1条目的数量以及输出数组的大小是依赖于数据的)。
但是如果你事先知道每批有五个条目,你可以做如下的事情:

from functools import partial
indices = jax.vmap(partial(jnp.where, size=5))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
  [2 1]
  [2 3]
  [3 2]
  [3 3]]

 [[0 0]
  [0 2]
  [1 0]
  [3 1]
  [3 3]]]

如果您不知道 * 先验 * 有多少1条目,那么您有几个选择:一种是避免JAX转换,并在每个批处理上调用未转换的jnp.where
一个二个一个一个
注意,对于这种情况,通常不可能将结果存储在单个数组中,因为每个批处理中可能有不同数量的1条目,并且JAX不支持不规则数组。
另一个选项是将size设置为某个最大值,并输出填充的结果:

max_size = a[0].size  # size of slice is the upper bound
fill_value = a[0].shape  # fill with out-of-bound indices
indices = jax.vmap(partial(jnp.where, size=max_size, fill_value=fill_value))(a)
print(jnp.stack(indices, axis=2))
[[[0 3]
  [2 1]
  [2 3]
  [3 2]
  [3 3]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]]

 [[0 0]
  [0 2]
  [1 0]
  [3 1]
  [3 3]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]
  [4 4]]]

有了填充的结果,您就可以编写代码的其余部分来预测这些填充的值。

相关问题