numpy 按索引对ndarray中的元素进行分组

pgky5nke  于 2023-02-04  发布在  其他
关注(0)|答案(1)|浏览(150)

我有一个包含1000张图像的图像数据集,我已经为它创建了嵌入。每个嵌入(每个图像的512个嵌入,具有256维矢量)是一个形状为(512,256)的ndarray,因此总的数组形状将是(1000,512,256)。
现在,从每个图像(1000),我想创建一组观测,用于第一次嵌入,512个可用的,并且从每个图像收集这个嵌入,然后我想这样做,用于第二次嵌入,第三次,第四次,直到第512次。
我该如何着手创建这些组?

6yt4nkrj

6yt4nkrj1#

您可以通过以下方式实现这一点:

groups = []

for i in range(512):
    # Select the i-th embedding from each image
    group = embeddings[:, i, :]
    groups.append(group)

groups = np.array(groups)

另一个优化的解决方案:

groups = np.array([embeddings[:, i, :] for i in range(512)])
groups = np.transpose(groups, (1, 0, 2))

相关问题