为什么我的简单MATLAB梯度下降线性回归不工作

x8diyxa7  于 2023-10-23  发布在  Matlab
关注(0)|答案(1)|浏览(179)

我开始学习线性回归。我想自己实现梯度下降。我写了下面的代码。

%% Linear regression

close all;

dataset =load('accidents');
data = dataset.hwydata;
x = data(:,14);
y  =data(:,4);
%% Gradient descent
% We want to minimize a cost function and GD achieves that iteratively.
% J(w,b) =(y-y_est)^2

w = 0;
b = 0;
alpha =.000001; % I tried various alphas like 0.01, .1 etc. (Not working)
for i =1: 100
    y_est = w*x + b;
    J = mean((y-y_est).^2)
    temp_w = w + alpha*(mean(x.*(y-w*x-b)));
    temp_b = b + alpha*(mean(y-w*x-b));
    w =temp_w
    b =temp_b
end

出于某种原因,它不起作用。我的算法好像没有收敛。
我希望算法能很好地收敛,因为均方误差成本函数是凸的。

toe95027

toe950271#

简短的回答:你需要沿梯度方向沿着进行某种线性搜索。
正如我在我的询问中提到的,大的x和y值会导致算法发散,这就是在你的程序中发生的事情,x值高达大约。3e7和偏导数w.r.t. w约为1e10。
对于一个非常粗糙和简单的线搜索:将当前优值函数(以下代码中的Jcurrent)与使用temp_wtemp_b和当前alpha(以下代码中的Jnew)计算的优值函数进行比较。
如果Jnew >= Jcurrent将alpha减小某个因子,然后使用新的alpha和当前梯度重新计算temp_wtemp_b。使用更新的temp_wtemp_b重新计算Jnew。重复直到Jnew < Jcurrent
在此行搜索之后,您至少有两个选择:1)将alpha重置为初始值(代码中的alphaOrig)或2)保持当前alpha。
请注意,这种线性搜索距离最优搜索 * 非常远 *。它只搜索价值函数的减少,并接受它,而不管减少有多小。这导致收敛缓慢。如果你想让我建议更好的线性搜索方法,请告诉我。
江淮

%% Linear regression

clear; 
% Delete all figures
figureList = findobj('type', 'figure');
if ~isempty(figureList)
    delete(figureList);
end
    alphaOrig = 1e-4;
    dataset =load('accidents');
    data = dataset.hwydata;
    x = data(:,14);
    y  =data(:,4);
% ..Check the ranges of x and y    
    fprintf(1, 'max(x) = %g\tmax(y) = %g\n', max(abs(x)), max(abs(y)) );
%% Gradient descent
% We want to minimize a cost function and GD achieves that iteratively.
% J(w,b) =(y-y_est)^2

    w = 0;
    b = 0;
    alpha = alphaOrig;
% ..Initial value of merit function
    Jcurrent = mean((y - w*x - b).^2);
    for i =1: 100
        y_est = w*x + b;
        J = mean((y-y_est).^2);
    % ..Components of the gradient (actually, minus gradient)
        dJdw = mean(x.*(y-w*x-b));
        dJdb = mean(y-w*x-b);
        fprintf(1, '%d: J(%g,%g) = %.8g\t', i, w,b,Jcurrent);
        fprintf(1, 'dJ/dw = %g; dJ/db = %g\n', dJdw, dJdb);
    % ..Crude line search
        while true % Loop not in the original at stackoverflow
            temp_w = w + alpha*dJdw;
            temp_b = b + alpha*dJdb;
            Jnew = mean((y - temp_w*x - temp_b).^2);
            fprintf(1, '\talpha = %g;\tJ(%g,%g) = %.8g; \n', ...
                alpha, temp_w, temp_b, Jnew);
            if Jnew < Jcurrent
                break;
            end
            alpha = 0.1*alpha; % <== reduce alpha
            if alpha == 0
                error("Sorry. It's not going well");
            end
        end
        Jcurrent = Jnew;
%         alpha = alphaOrig; % This resets alpha
        w =temp_w;
        b =temp_b;
    end
    fmt = "%10s: w = %g; b= %g\n";
%     fprintf(1, fmt, "Noise-free", wa, ba);
    fprintf(1, fmt, "Estimate", w, b);
    fig = figure(1);clf
    plot(x,y, 'linestyle', 'none','marker', 'o');
    hold on
% ..Show the least-squares fit
    t = [0.9*min(x),1.2*max(x)];
    plot(t,w*t+b, 'linestyle', '--', 'color', 'black');

相关问题