在3D numpy数组上执行操作的快速而优雅的方式

gt0wga4j  于 2023-05-17  发布在  其他
关注(0)|答案(1)|浏览(147)

我有以下数据:

  • img:表示形状为(200,200,256)的高光谱图像的3D numpy阵列,即200 x200像素,每个像素在均匀间隔的波长(波段)处保持256个整数反射率值
  • reference_spectrum:一个numpy数组,包含256个整数反射率值(类似于单个像素)
  • dist:一个距离函数,接受两个相同长度的numpy数组并返回一个真实的

我想得到一个2D numpy数组result,形状为(200,200),使得result[i, j] == dist(img[i, j], reference_spectrum)
首先,我写了一个明显的嵌套for循环:

result = np.zeros((img.shape[0], img.shape[1]))
for i in range(img.shape[0]):
    for j in range(img.shape[1]):
        result[i, j] = dist(img[r, c], reference_spectrum)

这是缓慢的,看起来很丑。
然后我使用距离函数$dist(a,B)= \sum_{i=0}^{n}| a_i-b_i|我写了下面的代码:

temp = np.zeros(img.shape)
for k in range(img.shape[2]):
    temp[: , :, k] = abs(img[: , :, k] - reference_spectrum[k])
result = np.sum(temp, axis=2)

这更快,但仅适用于特定的距离函数。
我正在寻找一个快速和模块化(w.r.t.距离函数)解。

mefy6pfw

mefy6pfw1#

可以扩展参照的暗显,并在最后一个暗显上应用操作。例如,假设您想要mse错误:

import numpy as np
data = np.random.randn(200, 200, 256)
ref = np.random.randn(256)

# mse = np.mean(error**2)
error = (data - ref.reshape(1, 1, -1))
mse = np.mean(error**2, axis=-1)
print(mse.shape)

还有numpy.apply_along_axis,它提供了沿着一个轴应用func的可能性

相关问题