如何使用scipy.integrate.solve_ivp以向量化方式处理具有耦合微分方程的numpy数组输入

bvhaajcl  于 12个月前  发布在  其他
关注(0)|答案(1)|浏览(113)

假设我想解一列方程,但它们并不耦合,只是由numpy数组定义的不同。
第一个月
其中c只是numpy数组中给定的一个已知常数,例如c = np.array([1, 2])。这意味着我需要求解两个方程,可能还需要不同的初始值。当然,这可以很容易地通过解析求解。
y(t) = t^2 + c*t + C
其中C是另一个依赖于y(0)的初始值的常数。例如,第一个方程有C = y(0) = 0,第二个方程有C = y(0) = 1。然后我可以得到精确解:

y =  t^2 + 1*t
y =  t^2 + 2*t + 1

字符串
如果我想使用scipy.integrate.solve_ivp来获取t=1处的值,那么我可以“假装”它们是耦合的。

def test_fun(t, y):
    c = np.array([1, 2])
    dy1 = 2*t + c
    return dy1
sol = solve_ivp(test_fun, [0, 1], np.array([0, 1]))
y = sol.y[:, -1]


这实际上是给我的y = np.array([2., 4.]),它与给予的解析值相匹配。
但是现在假设我有一个带有常数的方程列表,不同的是一个numpy数组,但是每个方程都由3个耦合的微分方程组成。让我们仍然使用一个玩具示例,它们不是真正的“耦合”,而是只需要在函数中返回多个值,就像你在求解耦合微分方程时所做的那样。

def test_fun(t, y):
    c = np.array([1, 2])
    dy1 = 2*t + c
    dy2 = 3*t + c
    dy3 = 4*t + c
    return dy1, dy2, dy3

sol = solve_ivp(test_fun, [0, 1], np.array([[1, 2], [2, 2], [3, 2]]))


但是现在它会抱怨ValueError: y0 must be 1-dimensional. Flatten y0似乎没有意义,因为我希望函数的输入是一个大小为2的数组,但是对于每个元素,它将返回一个由3个耦合方程组成的系统。如果我只是flatten,它将假设它是一个由6个耦合方程组成的系统的输入,然后抱怨ValueError: could not broadcast input array from shape (3,2) into shape (6,)
当然,我可以只使用for循环,但一般的想法是避免它,因为它在python中非常慢,而且我的c数组非常大。

pgx2nnw8

pgx2nnw81#

  • “展平y0似乎没有意义,因为我希望函数的输入是一个大小为2的数组,但对于每个元素,它将返回一个由3个耦合方程组成的系统。如果我只是展平,它将假设它是一个由6个耦合方程组成的系统的输入,然后抱怨ValueError:could not broadcast input array from shape(3,2)into shape(6,)"*

在将y0传递给求解器之前,必须对其进行展平。然后test_fun将得到展平的向量,因此在test_fun中,您将对y进行整形,使用数组进行计算,然后在从test_fun返回导数数组之前对其进行展平。(在您的简单示例中,没有使用y,因此您不必对其进行整形。)当求解器返回时,您还必须对结果进行整形,使其看起来像一个数组的数组。
我创建了一个名为odeintw的软件包,它可以为你做这件事,但它使用的是scipy.integrate.odeint,而不是solve_ivp。下面是如果你使用odeintw,你的脚本可能会是什么样子:

import numpy as np
from odeintw import odeintw

def test_fun(t, y):
    c = np.array([1, 2])
    dy1 = 2*t + c
    dy2 = 3*t + c
    dy3 = 4*t + c
    return dy1, dy2, dy3

n = 250
t = np.linspace(0, 1, n)
y0 = np.array([[1, 2], [2, 2], [3, 2]])
sol = odeintw(test_fun, y0, t, tfirst=True)
# The numerical solution `sol` is an array with shape (n, 3, 2).

字符串
您需要运行一些测试,看看这种方法是否比在单个系统上运行Python循环更快。(为了提高性能,您可以尝试实现odeintDfun参数。请参阅odeintw文档字符串了解如何实现它。)

相关问题