我有下面的python代码:
import matplotlib.pyplot as plt
import numpy as np
class bifurcation_diagram(object):
def __init__(self):
self.omega = []
self.theta = []
self.dt = (2 * np.pi / (2.0 / 3.0)) / 600
self.time = []
self.theta_in_bifurcation_diagram = []
self.F_D = np.arange(1.35,1.5,0.001)
self.theta_versus_F_D = []
def calculate(self):
l = 9.8
g = 9.8
q = 0.5
Omega_D = 2.0 / 3.0
for f_d in self.F_D:
self.omega.append([0])
self.theta.append([0.2])
self.time.append([0])
for i in range(600000):
k1_theta = self.dt * self.omega[-1][-1]
k1_omega = self.dt * ((-g / l) * np.sin(self.theta[-1][-1]) - q * self.omega[-1][-1] + f_d * np.sin(Omega_D * self.time[-1][-1]))
k2_theta = self.dt * (self.omega[-1][-1] + 0.5 * k1_omega)
k2_omega = self.dt * ((-g / l) * np.sin(self.theta[-1][-1] + 0.5 * k1_theta) - q * (self.omega[-1][-1] + 0.5 * k1_omega) + f_d * np.sin(Omega_D * (self.time[-1][-1] + 0.5 * self.dt)))
k3_theta = self.dt * (self.omega[-1][-1] + 0.5 * k2_omega)
k3_omega = self.dt * ((-g / l) * np.sin(self.theta[-1][-1] + 0.5 * k2_theta) - q * (self.omega[-1][-1] + 0.5 * k2_omega) + f_d * np.sin(Omega_D * (self.time[-1][-1] + 0.5 * self.dt)))
k4_theta = self.dt * (self.omega[-1][-1] + k3_omega)
k4_omega = self.dt * ((-g / l) * np.sin(self.theta[-1][-1] + k3_theta) - q * (self.omega[-1][-1] + k3_omega) + f_d * np.sin(Omega_D * (self.time[-1][-1] + self.dt)))
temp_theta = self.theta[-1][-1] + (1.0 / 6.0) * (k1_theta + 2 * k2_theta + 2 * k3_theta + k4_theta)
temp_omega = self.omega[-1][-1] + (1.0 / 6.0) * (k1_omega + 2 * k2_omega + 2 * k3_omega + k4_omega)
while temp_theta > np.pi:
temp_theta -= 2 * np.pi
while temp_theta < -np.pi:
temp_theta += 2 * np.pi
self.omega[-1].append(temp_omega)
self.theta[-1].append(temp_theta)
self.time[-1].append(self.dt * i)
for i in range(500,1000):
n = i * 600
self.theta_in_bifurcation_diagram.append(self.theta[-1][n])
self.theta_versus_F_D.append(f_d)
def show_results(self):
plt.plot(self.theta_versus_F_D,self.theta_in_bifurcation_diagram,'.')
plt.title('Bifurcation diagram' + '\n' + r'$\theta$ versus $F_D$')
plt.xlabel(r'$F_D$')
plt.ylabel(r'$\theta$ (radians)')
plt.xlim(1.35,1.5)
plt.ylim(1,3)
plt.show()
bifurcation = bifurcation_diagram()
bifurcation.calculate()
bifurcation.show_results()
我想把它做得更紧凑,做一个彩色的图,沿着提高它的效率。在这方面的任何帮助都将是真正有益的。
我期望代码运行得快,但它需要超过15分钟的运行时间。
1条答案
按热度按时间7gcisfzg1#
Python不是为这样的代码而设计的,因为它通常使用CPython进行解释(这意味着要为粘合代码和脚本完成)和CPython执行(几乎)没有对代码进行优化,因此重复的表达式会被重新计算。加快此类代码速度的常用方法是使用Numpy函数对大型数组进行操作(不是列表)。这被称为矢量化。也就是说,这段代码非常昂贵,使用Numpy可能还不够,在这里使用Numpy也是一个不可忽视的工作。另一种解决方案是使用Numbaso编译这段代码注意,Numba旨在加速主要使用Numpy和基本数值计算的代码(而不是像绘图这样的通用代码)。
首先要做的是摆脱列表并将其替换为Numpy数组。数组可以预先分配合适的大小,以避免对
append
的昂贵调用。然后,计算功能应该从类中移走,以便Numba易于使用。然后,两个while
循环可以替换为np.remainder
。最后,你可以使用多个线程并行计算数组的每一行。下面是生成的代码(几乎没有测试过):
在我的6核i5- 9600 KF处理器上,这个代码需要1.07秒,而不是最初的19分20秒!因此,这个新代码大约快150倍!
此外,我还发现生成的代码更容易阅读。
我建议你检查一下结果,虽然它们乍一看很好。事实上,这是我到目前为止得到的结果图: