numpy 如何在python中求解非连续方程?

bis0qfac  于 2022-02-20  发布在  Python
关注(0)|答案(1)|浏览(153)

我有一个功能:

p = np.arange(0,1,0.01)
EU = (((-(1.3+1-0.5)**(1.2-1.5)/(1-1.5)))**-1)*p + (((-(0.1+1-0.5)**(1.2-1.5)/(1-1.5)))**-1)*(1-p)

我想解决这个问题,得到p的值,其中EU == 0.5。我的问题是,我手动检查了p值为0.4244,但在函数中,p以0.01的步长跳跃。
现在,我很好在EU最接近0.5的地方得到最接近的p值,但我担心代码找不到答案,因为p数组中没有一个值会为EU精确返回0.5。我该如何解决这个问题呢?

f2uvfpb9

f2uvfpb91#

你的函数是一个简单的线性函数,所以它可以通过一个简单的二分查找找到,但是如果函数是更复杂的,你需要真正计算梯度和runt他优化你可以使用 Torch 或jax或任何其他autograd工具来做

import torch
def func(p):
    return (((-(1.3+1-0.5)**(1.2-1.5)/(1-1.5)))**-1)*p + (((-(0.1+1-0.5)**(1.2-1.5)/(1-1.5)))**-1)*(1-p)

x = torch.tensor([0.1], requires_grad=True)
loss = 1
# 1e-7 example threshold for answer to stop
while loss > 1e-7:
    x.retain_grad()
    y = func(x)
    loss = (y - torch.tensor([0.5], requires_grad=True)) ** 2 
    # or use MSELoss from nn  
    loss.backward()
    x = x - 0.1 * x.grad # naive gradient decent with lr = 0.1

print(x.item())

0.42235034704208374

相关问题