过滤numpy.dstack

ecfsfe2w  于 2023-01-13  发布在  其他
关注(0)|答案(1)|浏览(84)

我有一个这样的dstack:

import numpy as np
a = np.array((1,2,6))
b = np.array((2,3,4))
c = np.array((8,3,0))
stack = np.dstack((a,b,c))
print(stack)
#[[[1 2 8]
  #[2 3 3]
  #[6 4 0]]]

我想过滤掉元素2小于1的列表。
大概是这样的

new_list = []

for i in stack:
    for d in i[:,2]:
        if d>=1:
            new_list.append(d)
print(new_list) # [8,3]

这样做只添加了2个元素,但我希望有所有的行,如下所示:

#[[[1 2 8]
  #[2 3 3]]]

如果I append(i),结果也不是期望的结果。

ql3eal8s

ql3eal8s1#

你不需要一个循环,你可以用切片来完成

print(stack[stack[:,2] >= 1])

产出

[[1 2 8]
 [2 3 3]]

如果您需要它作为

[[[1 2 8]]
 [[2 3 3]]]

你可以把结果

stack = stack[stack[:,2] >= 1]
shape = stack.shape
print(stack[stack[:,2] >= 1].reshape((shape[0], 1, shape[1])))

相关问题