在循环中通过索引数组索引numpy数组

lqfhib0f  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(117)

我有一个向量,我想批量 Shuffle 。
我的想法是把它重新塑造成2D数组,每行作为一个批次。
然后我自己 Shuffle 。
这是一个很好的例子

# shuffle the matrix
mat_size = (8, 8)
row_size = 4

# generate the row and column indices
# shuffle the column
col_idx = np.arange(mat_size[0] * mat_size[1], dtype = np.int32) 

tmp_mat = np.reshape(col_idx, (-1, row_size))

for row in tmp_mat:
  idx = np.random.choice(row_size, size = row_size, replace = False)
  row[idx] = row # in place on col_idx

tmp_mat

我得到的结果是:

array([[ 1,  0,  0,  0],
       [ 4,  5,  6,  4],
       [ 8,  9,  9,  8],
       [12, 12, 14, 12],
       [16, 17, 17, 19],
       [21, 20, 20, 21],
       [25, 26, 24, 24],
       [28, 29, 28, 31],
       [33, 34, 32, 32],
       [36, 36, 38, 38],
       [40, 41, 41, 43],
       [44, 44, 46, 46],
       [48, 48, 50, 48],
       [53, 52, 53, 52],
       [56, 57, 56, 57],
       [61, 60, 62, 60]])

问题是这些行不是输入行的混洗版本。
可以看出原始数组每行具有唯一值。我不知道这是怎么回事。
我发现如果我用row[:] = row[idx]替换row[idx] = row,它就能工作。
有什么解释吗
为什么任务没有我期望的那样成功?

cu6pst1q

cu6pst1q1#

您的问题是由于在引用该行时尝试修改该行。因此,在内部,当您访问右侧的项时,if可能已经被替换。
如果您使用副本,则不会出现问题:

row[idx] = row.copy()

也不是,正如你所发现的那样:

row[:] = row[idx]

它也使用了一个副本。
请注意,您可以使用this recipe以向量方式重排数组:

def scramble(a, axis=-1):
    """
    Return an array with the values of `a` independently shuffled along the
    given axis
    """ 
    b = a.swapaxes(axis, -1)
    n = a.shape[axis]
    idx = np.random.choice(n, n, replace=False)
    b = b[..., idx]
    return b.swapaxes(axis, -1)

out = scramble(tmp_mat)

或者,如果列数有限,则相对有效:

out = tmp_mat[np.arange(tmp_mat.shape[0])[:,None],
              np.argsort(np.random.random(tmp_mat.shape))]

numpy.random.Generator.permuted

rng = np.random.default_rng()
out = rng.permuted(tmp_mat, axis=-1)

计时

(on提供的示例):

# loop
164 µs ± 5.96 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# scramble
13.5 µs ± 536 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

# argsort
6.01 µs ± 511 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

# permuted
2.21 µs ± 82.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

在10k行上:

# loop
100 ms ± 1.84 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# scramble
29.7 µs ± 673 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# argsort
777 µs ± 61.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

# permuted
511 µs ± 17.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

相关问题