numpy 如何优化此循环

hc2pp10m  于 2023-01-26  发布在  其他
关注(0)|答案(1)|浏览(123)

输入是观察列表,每个观察是椭圆的固定大小集合(每个椭圆由7个参数表示)。
输出是一个图像列表,一个图像对应一个观察结果,我们基本上是把观察结果的椭圆放到全白色图像上。如果有几个椭圆重叠,我们就放rgb值的平均值。

n, m = size of image in pixels, image is represented as (n, m, 3) numpy array (3 because of RGB coding)
N = number of elipses in every individual observation

xx, yy = np.mgrid[:n, :m]

def elipses_population_to_img_population(elipses_population):
        population_size = elipses_population.shape[0]
        img_population = np.empty((population_size, n, m, 3))
        for j in range(population_size):
            imarray = np.empty((N, n, m, 3))
            imarray.fill(np.nan)
            for i in range(N):
                x = elipses_population[j, i, 0]
                y = elipses_population[j, i, 1]
                R = elipses_population[j, i, 2]
                G = elipses_population[j, i, 3]
                B = elipses_population[j, i, 4]
                a = elipses_population[j, i, 5]
                b = elipses_population[j, i, 6]
                xx_centered = xx - x
                yy_centered = yy - y
                elipse = (xx_centered / a)**2 + (yy_centered / b)**2 < 1
                imarray[i, elipse, :] = np.array([R, G, B])
            means_img = np.nanmean(imarray, axis=0)
            means_img = np.nan_to_num(means_img, nan=255)
            img_population[j, :, :, :] = means_img
        return img_population

代码工作正常,但我正在寻找优化建议。我在我的代码中运行了很多次,所以每一个小的改进都会很有帮助。

vof42yt1

vof42yt11#

def elipses_population_to_img_population(elipses_population):
  population_size, N, _ = elipses_population.shape
  xx_centered = xx[np.newaxis, np.newaxis, :, :] - elipses_population[:,:,0, np.newaxis, np.newaxis]
  yy_centered = yy[np.newaxis, np.newaxis, :, :] - elipses_population[:,:,1, np.newaxis, np.newaxis]
  a = elipses_population[:,:,5, np.newaxis, np.newaxis]
  b = elipses_population[:,:,6, np.newaxis, np.newaxis]
  elipse = ((xx_centered / a)**2 + (yy_centered / b)**2) < 1
  img_population = elipses_population[:,:, 2:5, np.newaxis, np.newaxis]
  img_population[elipse] = np.nan
  img_population = np.nanmean(img_population, axis=1)
  img_population = np.nan_to_num(img_population, nan=255)
  return img_population

这避免了使用显式的for循环,而是使用了numpy广播和向量化操作,这应该会带来性能提升。

相关问题