C语言 实现具有交叉熵损失的Softmax输出层

sbdsn5lh  于 10个月前  发布在  其他
关注(0)|答案(1)|浏览(116)

我正在玩这个repo(https://github.com/SnailWalkerYC/LeNet-5_Speed_Up),并试图学习NN的细节。这个repo在C和CUDA中实现了LeNet 5。我现在专注于CPU部分及其在seq/中的代码。我迷路的一个特别地方是seq/lenet. c中的这个函数

static inline void softmax(double input[OUTPUT], double loss[OUTPUT], int label, int count){
      double inner = 0; 
      for (int i = 0; i < count; ++i){
           double res = 0;
           for (int j = 0; j < count; ++j){
              res += exp(input[j] - input[i]);
           }
           loss[i] = 1. / res;
           inner -= loss[i] * loss[i];
       }
       inner += loss[label];
       for (int i = 0; i < count; ++i){
          loss[i] *= (i == label) - loss[i] - inner;
       }
}

字符串
因为没有评论,我花了一些时间来理解这个函数。最后我发现它是计算MSE损失函数关于softmax层输入的导数。
然后我尝试使用交叉熵损失函数和softmax,所以我用下面的函数来代替上面的函数。

static inline void softmax(double input[OUTPUT], double loss[OUTPUT], int label, int count)
{
    double inner = 0;
    double max_input = -INFINITY;

    // Find the maximum input value to prevent numerical instability
    for (int i = 0; i < count; ++i)
    {
        if (input[i] > max_input)
            max_input = input[i];
    }

    // Compute softmax and cross-entropy loss
    double sum_exp = 0;
    for (int i = 0; i < count; ++i)
    {
        double exp_val = exp(input[i] - max_input);
        sum_exp += exp_val;
        loss[i] = exp_val;
    }

    double softmax_output[OUTPUT];
    for (int i = 0; i < count; ++i)
    {
        loss[i] /= sum_exp;
        softmax_output[i] = loss[i];
    }

    // Compute cross-entropy loss and derivatives
    inner = -log(softmax_output[label]);
    for (int i = 0; i < count; ++i)
    {
        
    loss[i] = softmax_output[i] - (i == label);
    }
}


然而,使用我的softmax()函数版本,MNIST识别不起作用。原始版本达到了>96%的准确率。我的交叉熵损失代码有什么问题吗?

gzszwxb4

gzszwxb41#

好吧,我会回答我自己的问题。
我设法使交叉熵损失与softmax一起工作。有两个地方需要调整:
1.交叉熵w.r.t.对softmax输入的导数具有(prediction - label)的形式,如在许多地方所见。

loss[i] = (i == label) - softmax_output[i]; // opposite of the common form

字符串
而不是

loss[i] = softmax_output[i] - (i == label); // conforms to the common form


这是因为CNN更新权重和偏差的方式是违反直觉的:

double k = ALPHA / batchSize;
FOREACH(i, GETCOUNT(LeNet5))
    ((double *)lenet)[i] += k * buffer[i];


因此,要么使用损失导数的相反形式,要么改变权重更新,

((double *)lenet)[i] -= k * buffer[i];


1.即使有上述修复,这个MNIST CNN仍然不起作用。有时训练过程中的损失并没有下降。我注意到的另一个奇怪的事情是学习率(代码中的Alpha)。这段代码将学习率硬编码为0.5。我见过的大多数地方都使用0.1
有了上述1)和2),模型现在预测一致性> 96水平。
在这个过程中,我学到了很多东西。

相关问题