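I ran the same test code in Python+NumPy and in MATLAB and found that the MATLAB code is an order of magnitude faster. I'd like to know what the bottleneck in the Python code is and how to speed it up.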
I ran the following test code with Python+NumPy (the last part is the performance-sensitive part):
```python
# Packages
import numpy as np
import time

# Number of possible outcomes
num_outcomes = 20
# Dimension of the system
dim = 50
# Number of iterations
num_iterations = int(1e7)

# Possible outcomes
outcomes = np.arange(num_outcomes)
# Possible transition matrices
matrices = [np.random.rand(dim, dim) for k in outcomes]
matrices = [mat/np.sum(mat, axis=0) for mat in matrices]
# Initial state
state = np.random.rand(dim)
state = state/np.sum(state)
# List of samples
samples = np.random.choice(outcomes, size=(num_iterations,))
samples = samples.tolist()

# === PERFORMANCE-SENSITIVE PART OF THE CODE ===
# Update the state over all iterations
start_time = time.time()
for k in range(num_iterations):
    sample = samples[k]
    matrix = matrices[sample]
    state = np.matmul(matrix, state)
end_time = time.time()

# Print the execution time
print(end_time - start_time)
```
Then I ran equivalent code in MATLAB (again, the last part is the performance-sensitive part):
```matlab
% Number of possible outcomes
num_outcomes = 20;
% Number of dimensions
dim = 50;
% Number of iterations
num_iterations = 1e7;

% Possible outcomes
outcomes = 1:num_outcomes;
% Possible transition matrices
matrices = rand(num_outcomes, dim, dim);
matrices = matrices./sum(matrices,2);
matrices = num2cell(matrices,[2,3]);
matrices = cellfun(@shiftdim, matrices, 'UniformOutput', false);
% Initial state
state = rand(dim,1);
state = state./sum(state);
% List of samples
samples = datasample(outcomes, num_iterations);

% === PERFORMANCE-SENSITIVE PART OF THE CODE ===
% Update the state over all iterations
tic;
for k = 1:num_iterations
    sample = samples(k);
    matrix = matrices{sample};
    state = matrix * state;
end
toc;
```
The Python code is consistently an order of magnitude slower than the MATLAB code, and I'm not sure why.
Any idea where to start?
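I ran the Python code with a Python 3.10 interpreter and NumPy 1.22.4, and the MATLAB code with MATLAB R2022a. Both ran under Windows 11 Pro 64-bit on a Lenovo ThinkPad T14 with the following processor:
11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz, 2803 MHz, 4 cores, 8 logical processors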
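Edit 1: I ran some additional tests, and it looks like the culprit is some kind of constant, Python-specific overhead that dominates at small matrix sizes: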
As hpaulj and MSS suggested, this may mean that a JIT compiler could solve some of these issues. I'll try that as best I can in the near future.
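A minimal sketch for checking the fixed-overhead hypothesis: time a single `np.matmul(matrix, state)` call at several sizes with the standard-library `timeit` module. The helper name `matmul_time_per_call`, the sizes, and the repetition count are illustrative assumptions. If the per-call time barely changes between the smallest sizes, the cost is dominated by fixed per-call overhead rather than by the arithmetic itself.

```python
import timeit

import numpy as np


def matmul_time_per_call(dim, number=2000):
    """Average wall-clock seconds for one np.matmul(matrix, state) call at size dim."""
    rng = np.random.default_rng(0)
    matrix = rng.random((dim, dim))
    state = rng.random(dim)
    total = timeit.timeit(lambda: np.matmul(matrix, state), number=number)
    return total / number


for dim in (2, 10, 50, 200):
    t = matmul_time_per_call(dim)
    # Time per matrix element shrinks as dim grows if a fixed overhead dominates.
    print(f"dim={dim:4d}: {t * 1e6:8.2f} us per call, {t / (dim * dim) * 1e9:8.3f} ns per element")
```

With 10^7 iterations, every microsecond of fixed per-call cost adds about 10 seconds to the total run time.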
1 answer
In the loop, the main bottleneck is `np.matmul(matrix, state)`. If we unroll the loop:
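Written out, with $x_0$ the initial `state`, $s_k$ = `samples[k]`, $M_{s_k}$ = `matrices[samples[k]]` and $N$ = `num_iterations`, the loop computes

$$
x_N = M_{s_{N-1}} M_{s_{N-2}} \cdots M_{s_1} M_{s_0}\, x_0 .
$$

Each step depends on the result of the previous one, so the iterations cannot simply run side by side. Matrix multiplication is associative, though, so the factors can be grouped in pairs, $(M_{s_1} M_{s_0})$, $(M_{s_3} M_{s_2})$, and so on, and each level of grouping halves the number of factors.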
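There is no obvious vectorized way to replace the loop of `np.matmul` calls with a loop-free form. A better approach is to do it in log_2(n) loop iterations instead of n.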
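One way the log_2(n) idea might look, as a sketch rather than the answerer's exact code: stack the sampled matrices and multiply neighbouring pairs with a single batched `np.matmul` per round, so each round halves the stack. Stacking all 10^7 matrices of size 50×50 in float64 would need about 200 GB, so this sketch assumes the samples are processed in chunks; `tree_product`, `evolve_tree` and `chunk_size` are hypothetical names. Note the trade-off: it replaces n matrix-vector products with about n matrix-matrix products, so it does far fewer Python-level calls but more arithmetic, and whether it wins depends on `dim` and the BLAS in use.

```python
import numpy as np


def tree_product(stack):
    """Return stack[n-1] @ ... @ stack[1] @ stack[0] using about log2(n) batched matmuls."""
    stack = np.asarray(stack, dtype=float)
    dim = stack.shape[-1]
    while stack.shape[0] > 1:
        if stack.shape[0] % 2 == 1:
            # Append an identity as the last (leftmost) factor so every matrix has a partner.
            stack = np.concatenate([stack, np.eye(dim)[None]], axis=0)
        # Later matrices act from the left: pair (A[2i], A[2i+1]) -> A[2i+1] @ A[2i].
        stack = np.matmul(stack[1::2], stack[0::2])
    return stack[0]


def evolve_tree(matrices, samples, state, chunk_size=1024):
    """Apply matrices[samples[k]] for k = 0, 1, ... to state, one chunk at a time."""
    mats = np.asarray(matrices)   # shape (num_outcomes, dim, dim)
    samples = np.asarray(samples)
    for start in range(0, len(samples), chunk_size):
        # Fancy indexing builds the (chunk, dim, dim) stack for this chunk of samples.
        chunk = mats[samples[start:start + chunk_size]]
        state = tree_product(chunk) @ state
    return state
```

In the question's setup this would be called as `state = evolve_tree(matrices, samples, state)`. Results can differ from the plain loop in the last few digits because the floating-point operations are grouped differently.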
Using Numba's JIT compiler can reduce the time further.
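A minimal sketch of the Numba route, assuming the matrices are stacked into one 3-D array and the samples are kept as an integer array. The function name `evolve_numba` is hypothetical, and the hand-written matrix-vector loop is one choice among several (it avoids relying on `@`/`np.dot` support inside Numba). The `try/except` fallback only exists so the sketch runs, slowly, without Numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (as plain, slow Python) without numba.
    def njit(func):
        return func


@njit
def evolve_numba(mats, samples, state):
    """Apply mats[samples[k]] to state for every k, in order.

    mats: float64 array of shape (num_outcomes, dim, dim)
    samples: integer array of outcome indices
    state: float64 array of shape (dim,)
    """
    dim = state.shape[0]
    cur = state.copy()
    nxt = np.empty_like(cur)
    for k in range(samples.shape[0]):
        m = mats[samples[k]]
        # Hand-written matrix-vector product; numba compiles these loops to machine code.
        for i in range(dim):
            acc = 0.0
            for j in range(dim):
                acc += m[i, j] * cur[j]
            nxt[i] = acc
        # Swap the two buffers instead of allocating a new array every iteration.
        cur, nxt = nxt, cur
    return cur
```

In the question's setup: `state = evolve_numba(np.stack(matrices), np.asarray(samples), state)`. The first call includes compilation time, so time a second call (or a warm-up on a short sample array) when comparing against MATLAB.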