python 如何使用Dask并行化迭代和更新numpy数组

mbzjlibv  于 2023-08-02  发布在  Python
关注(0)|答案(1)|浏览(123)

我有一个非常大的距离矩阵,我需要迭代每个值,并在条件为真时更新距离。
下面是我的Pandas/Numpy代码块:

dist_mat = pd.read_csv()
date_list = metadata['sample_collection_date'].values
numpy_arr = dist_mat.values
columns = dist_mat.columns.tolist()
col_index = 0
for i in range(dist_mat.shape[0]):
    numpy_arr[i][i] += 0.1
    for j in range(col_index):
        if abs(np.timedelta64(date_list[i] - date_list[j], 'D')) <= 14:
                numpy_arr[i][j] += 0.1
                numpy_arr[j][i] += 0.1
    col_index += 1

字符串
我试过使用Dask,但它并不比我使用Pandas/Numpy的速度快。我想知道什么是有助于并行处理此代码块的正确方法。

dist_mat = dd.read_csv(args.dist_file, sep='\t', skiprows=2, sample=10000000, assume_missing=True).set_index('#Sources')

date_list = metadata['sample_collection_date'].values
np_array = dist_mat.to_dask_array(lengths=True)
columns = dist_mat.columns.tolist()
col_index = 0
for i in range(dist_mat.shape[0].compute()):
    numpy_arr[i][i] += 0.1
    for j in range(col_index):
        if abs(np.timedelta64(date_list[i] - date_list[j], 'D')) <= 14:
                numpy_arr[i][j] += 0.1
                numpy_arr[j][i] += 0.1
    col_index += 1

x4shl7ld

x4shl7ld1#

问题似乎是您当前的Dask代码,您在循环中调用.compute()。这对于计算时间和存储器来说可能是非常昂贵的。
更有效的方法可能是使用Numba库,它可以JIT编译Python代码,并可以使用多个核心进行某些类型的操作-特别是ufuncs

举例如下

import numba
import numpy as np
import pandas as pd

@numba.njit(parallel=True)
def update_matrix(dist_mat, date_list, threshold=14, increment=0.1):
    N = dist_mat.shape[0]
    for i in numba.prange(N):
        dist_mat[i, i] += increment
        for j in range(i):
            if abs(np.timedelta64(date_list[i] - date_list[j], 'D')) <= threshold:
                dist_mat[i, j] += increment
                dist_mat[j, i] += increment
    return dist_mat

# Load your data
metadata = pd.read_csv('metadata.csv')
dist_mat = pd.read_csv('dist_mat.csv')

date_list = metadata['sample_collection_date'].values
dist_mat = dist_mat.values

# Update matrix
dist_mat = update_matrix(dist_mat, date_list)

字符串

相关问题