With Paddle 2.0rc's high-level API, training and evaluation accuracy stays around 11%, but the same model trained and evaluated with the basic API reaches 70% — am I making a mistake with the high-level API?

6ju8rftf · Posted 2021-12-07 in Java

Training and evaluating the model with PaddlePaddle 2.0rc's high-level APIs model.fit and model.evaluate, the accuracy stays around 11%:

import paddle
import paddle.nn.functional as F
import numpy as np

cifar10_train = paddle.vision.datasets.cifar.Cifar10(mode='train', transform=None)
cifar10_test = paddle.vision.datasets.cifar.Cifar10(mode='test', transform=None)

class MyNet(paddle.nn.Layer):
    def __init__(self, num_classes=10):
        super(MyNet, self).__init__()

        self.conv1 = paddle.nn.Conv2D(in_channels=3, out_channels=32, kernel_size=(3, 3))
        self.pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)

        self.conv2 = paddle.nn.Conv2D(in_channels=32, out_channels=64, kernel_size=(3,3))
        self.pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)

        self.conv3 = paddle.nn.Conv2D(in_channels=64, out_channels=64, kernel_size=(3,3))

        self.flatten = paddle.nn.Flatten()

        self.linear1 = paddle.nn.Linear(in_features=1024, out_features=64)
        self.linear2 = paddle.nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = F.relu(x)

        x = self.flatten(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

model = paddle.Model(MyNet())

optim = paddle.optimizer.Adam(
    learning_rate=0.001, parameters=model.parameters())

model.prepare(
    optim,
    loss=paddle.nn.CrossEntropyLoss(),
    metrics=paddle.metric.Accuracy())

model.fit(cifar10_train,
          cifar10_test,
          epochs=50,
          batch_size=32,
          log_freq=1000)
model.evaluate(cifar10_test, batch_size=64, verbose=1)

The training and evaluation output is:

Epoch 1/50
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return (isinstance(seq, collections.Sequence) and
step 1000/1563 - loss: 1.4179 - acc: 0.1062 - 8ms/step
step 1563/1563 - loss: 1.4963 - acc: 0.1081 - 8ms/step
Eval begin...
step 313/313 - loss: 1.1943 - acc: 0.1180 - 5ms/step
Eval samples: 10000
Epoch 2/50
step 1000/1563 - loss: 1.7969 - acc: 0.1115 - 9ms/step
step 1563/1563 - loss: 1.4865 - acc: 0.1138 - 9ms/step
Eval begin...
step 313/313 - loss: 1.3613 - acc: 0.1171 - 6ms/step
Eval samples: 10000
Epoch 3/50
step 1000/1563 - loss: 1.1616 - acc: 0.1100 - 9ms/step
step 1563/1563 - loss: 0.8361 - acc: 0.1115 - 9ms/step
Eval begin...
step 313/313 - loss: 1.1157 - acc: 0.1169 - 6ms/step
Eval samples: 10000
Epoch 4/50
step 1000/1563 - loss: 1.0758 - acc: 0.1135 - 10ms/step
step 1563/1563 - loss: 1.1662 - acc: 0.1157 - 10ms/step
Eval begin...
step 313/313 - loss: 1.0637 - acc: 0.1170 - 6ms/step
Eval samples: 10000
Epoch 5/50
step 1000/1563 - loss: 1.0203 - acc: 0.1121 - 9ms/step
step 1563/1563 - loss: 1.6569 - acc: 0.1141 - 9ms/step
Eval begin...
step 313/313 - loss: 0.9319 - acc: 0.1174 - 6ms/step
Eval samples: 10000
Epoch 6/50
step 1000/1563 - loss: 0.5771 - acc: 0.1214 - 9ms/step
step 1563/1563 - loss: 0.7482 - acc: 0.1204 - 9ms/step
Eval begin...
step 313/313 - loss: 0.8782 - acc: 0.1200 - 5ms/step
Eval samples: 10000
Epoch 7/50
step 1000/1563 - loss: 0.7608 - acc: 0.1191 - 9ms/step
step 1563/1563 - loss: 1.1097 - acc: 0.1188 - 9ms/step
Eval begin...
step 313/313 - loss: 1.2158 - acc: 0.1211 - 7ms/step
Eval samples: 10000
Epoch 8/50
step 1000/1563 - loss: 0.6291 - acc: 0.1177 - 9ms/step
step 1563/1563 - loss: 1.3995 - acc: 0.1187 - 9ms/step
Eval begin...
step 313/313 - loss: 1.2460 - acc: 0.1203 - 6ms/step
Eval samples: 10000
Epoch 9/50
step 1000/1563 - loss: 1.3622 - acc: 0.1205 - 7ms/step
step 1563/1563 - loss: 0.4886 - acc: 0.1210 - 6ms/step
Eval begin...
step 313/313 - loss: 0.8871 - acc: 0.1176 - 3ms/step
Eval samples: 10000
Epoch 10/50
step 1000/1563 - loss: 1.2783 - acc: 0.1183 - 5ms/step
step 1563/1563 - loss: 1.0119 - acc: 0.1197 - 5ms/step
Eval begin...
step 313/313 - loss: 1.1044 - acc: 0.1193 - 3ms/step
Eval samples: 10000
Epoch 11/50
step 1000/1563 - loss: 0.7276 - acc: 0.1178 - 4ms/step
step 1563/1563 - loss: 0.6809 - acc: 0.1186 - 4ms/step
Eval begin...
step 313/313 - loss: 1.6448 - acc: 0.1189 - 3ms/step
Eval samples: 10000
Epoch 12/50
step 1000/1563 - loss: 0.6836 - acc: 0.1197 - 5ms/step
step 1563/1563 - loss: 0.4037 - acc: 0.1213 - 5ms/step
Eval begin...
step 313/313 - loss: 1.3373 - acc: 0.1209 - 3ms/step
Eval samples: 10000
Epoch 13/50
step 1000/1563 - loss: 0.6929 - acc: 0.1203 - 5ms/step
step 1563/1563 - loss: 0.4770 - acc: 0.1195 - 4ms/step
Eval begin...
step 313/313 - loss: 1.2987 - acc: 0.1192 - 3ms/step
Eval samples: 10000
Epoch 14/50
step 1000/1563 - loss: 0.4943 - acc: 0.1211 - 5ms/step
step 1563/1563 - loss: 0.3968 - acc: 0.1202 - 5ms/step
Eval begin...
step 313/313 - loss: 1.4327 - acc: 0.1220 - 6ms/step
Eval samples: 10000

But if I train and evaluate the same model with the basic API, the accuracy rises to about 70%.
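The visible difference between the two pipelines is that the basic-API loop below normalizes each batch by hand (cast to float32, reshape to NCHW, divide by 255), while the high-level run above fed the dataset with `transform=None`, so `model.fit` may have seen raw, unnormalized data. A minimal numpy sketch of that manual preprocessing step, using a synthetic record rather than the real dataset:

```python
import numpy as np

def preprocess(img_u8):
    # Mirror the manual steps in the basic-API loop:
    # cast to float32, reshape flat HWC bytes to CHW, scale to [0, 1]
    return img_u8.astype("float32").reshape(3, 32, 32) / 255.0

# synthetic stand-in for one raw CIFAR-10 record (3072 uint8 values)
raw = np.arange(3072, dtype=np.uint8)
x = preprocess(raw)
assert x.shape == (3, 32, 32)
assert x.dtype == np.float32
assert 0.0 <= float(x.min()) and float(x.max()) == 1.0
```

If the high-level path skips this step, the network sees inputs on a [0, 255] scale, which is consistent with the stalled ~11% accuracy above.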

import paddle
import paddle.nn.functional as F
import numpy as np

cifar10_train = paddle.vision.datasets.cifar.Cifar10(mode='train', transform=None)

class MyNet(paddle.nn.Layer):
    def __init__(self, num_classes=10):
        super(MyNet, self).__init__()

        self.conv1 = paddle.nn.Conv2D(in_channels=3, out_channels=32, kernel_size=(3, 3))
        self.pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)

        self.conv2 = paddle.nn.Conv2D(in_channels=32, out_channels=64, kernel_size=(3,3))
        self.pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)

        self.conv3 = paddle.nn.Conv2D(in_channels=64, out_channels=64, kernel_size=(3,3))

        self.flatten = paddle.nn.Flatten()

        self.linear1 = paddle.nn.Linear(in_features=1024, out_features=64)
        self.linear2 = paddle.nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = F.relu(x)

        x = self.flatten(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

epoch_num = 50
batch_size = 32
learning_rate = 0.001

val_acc_history = []
val_loss_history = []

def train(model):
    print('start training ... ')
    # turn into training mode
    model.train()

    opt = paddle.optimizer.Adam(learning_rate=learning_rate,
                                parameters=model.parameters())

    train_loader = paddle.io.DataLoader(cifar10_train,
                                        shuffle=True,
                                        batch_size=batch_size)

    cifar10_test = paddle.vision.datasets.cifar.Cifar10(mode='test', transform=None)
    valid_loader = paddle.io.DataLoader(cifar10_test, batch_size=batch_size)

    for epoch in range(epoch_num):
        for batch_id, data in enumerate(train_loader()):
            x_data = paddle.cast(data[0], 'float32')
            x_data = paddle.reshape(x_data, (-1, 3, 32, 32)) / 255.0
            y_data = paddle.cast(data[1], 'int64')
            y_data = paddle.reshape(y_data, (-1, 1))
            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)

            if batch_id % 1000 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))
            loss.backward()
            opt.step()
            opt.clear_grad()

        # evaluate model after one epoch
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(valid_loader()):
            x_data = paddle.cast(data[0], 'float32')
            x_data = paddle.reshape(x_data, (-1, 3, 32, 32)) / 255.0
            y_data = paddle.cast(data[1], 'int64')
            y_data = paddle.reshape(y_data, (-1, 1))

            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc = paddle.metric.accuracy(logits, y_data)
            accuracies.append(np.mean(acc.numpy()))
            losses.append(np.mean(loss.numpy()))

        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("[validation] accuracy/loss: {}/{}".format(avg_acc, avg_loss))
        val_acc_history.append(avg_acc)
        val_loss_history.append(avg_loss)
        model.train()

model = MyNet()
train(model)

The basic-API training and evaluation output is:

start training ...
epoch: 0, batch_id: 0, loss is: [2.3293128]
epoch: 0, batch_id: 1000, loss is: [1.8194305]
[validation] accuracy/loss: 0.5231629610061646/1.3090485334396362
epoch: 1, batch_id: 0, loss is: [1.5358344]
epoch: 1, batch_id: 1000, loss is: [1.2667058]
[validation] accuracy/loss: 0.6137180328369141/1.0972559452056885
epoch: 2, batch_id: 0, loss is: [1.3501728]
epoch: 2, batch_id: 1000, loss is: [0.70167136]
[validation] accuracy/loss: 0.5982428193092346/1.1266995668411255
epoch: 3, batch_id: 0, loss is: [1.26873]
epoch: 3, batch_id: 1000, loss is: [0.9860645]
[validation] accuracy/loss: 0.6560503244400024/0.9876276254653931
epoch: 4, batch_id: 0, loss is: [1.0141135]
epoch: 4, batch_id: 1000, loss is: [1.047922]
[validation] accuracy/loss: 0.6822084784507751/0.9094789028167725
epoch: 5, batch_id: 0, loss is: [0.8284477]
epoch: 5, batch_id: 1000, loss is: [1.2929298]
[validation] accuracy/loss: 0.6776158213615417/0.9294835329055786
epoch: 6, batch_id: 0, loss is: [0.5874635]
epoch: 6, batch_id: 1000, loss is: [0.81808865]
[validation] accuracy/loss: 0.6781150102615356/0.9467935562133789
epoch: 7, batch_id: 0, loss is: [0.79639006]
epoch: 7, batch_id: 1000, loss is: [0.5610222]
[validation] accuracy/loss: 0.6814097166061401/0.9082037806510925
epoch: 8, batch_id: 0, loss is: [0.8221557]
epoch: 8, batch_id: 1000, loss is: [0.84624374]
[validation] accuracy/loss: 0.7076677083969116/0.8675845265388489
epoch: 9, batch_id: 0, loss is: [0.6109121]
epoch: 9, batch_id: 1000, loss is: [0.5778799]
[validation] accuracy/loss: 0.706569492816925/0.8552684783935547
epoch: 10, batch_id: 0, loss is: [0.4108306]
epoch: 10, batch_id: 1000, loss is: [0.73070085]
[validation] accuracy/loss: 0.7057707905769348/0.9030559659004211
epoch: 11, batch_id: 0, loss is: [0.7420476]
epoch: 11, batch_id: 1000, loss is: [0.82045287]
[validation] accuracy/loss: 0.713957667350769/0.8668591380119324
epoch: 12, batch_id: 0, loss is: [0.57150733]
epoch: 12, batch_id: 1000, loss is: [1.0905309]
[validation] accuracy/loss: 0.7201477885246277/0.8755870461463928
epoch: 13, batch_id: 0, loss is: [0.355359]
epoch: 13, batch_id: 1000, loss is: [0.85425407]
[validation] accuracy/loss: 0.7063698172569275/0.913926362991333
epoch: 14, batch_id: 0, loss is: [0.4514779]
epoch: 14, batch_id: 1000, loss is: [0.42346495]
[validation] accuracy/loss: 0.708765983581543/0.88392573595047
epoch: 15, batch_id: 0, loss is: [0.3419215]
epoch: 15, batch_id: 1000, loss is: [0.63300633]
[validation] accuracy/loss: 0.7082667946815491/0.950889527797699
epoch: 16, batch_id: 0, loss is: [0.3893841]
epoch: 16, batch_id: 1000, loss is: [0.6329537]
[validation] accuracy/loss: 0.7066693305969238/0.9287868142127991
epoch: 17, batch_id: 0, loss is: [0.21955445]
epoch: 17, batch_id: 1000, loss is: [0.45696437]
[validation] accuracy/loss: 0.7103634476661682/0.9799591898918152
epoch: 18, batch_id: 0, loss is: [0.22373167]
epoch: 18, batch_id: 1000, loss is: [0.5783404]
[validation] accuracy/loss: 0.7018769979476929/1.0424772500991821
epoch: 19, batch_id: 0, loss is: [0.4076147]
epoch: 19, batch_id: 1000, loss is: [0.26885363]
[validation] accuracy/loss: 0.7029752135276794/1.0611885786056519
epoch: 20, batch_id: 0, loss is: [0.6359902]
epoch: 20, batch_id: 1000, loss is: [0.3711901]
[validation] accuracy/loss: 0.7094648480415344/1.0525825023651123
epoch: 21, batch_id: 0, loss is: [0.35774338]
epoch: 21, batch_id: 1000, loss is: [0.37585694]
[validation] accuracy/loss: 0.7046725153923035/1.066388487815857
epoch: 22, batch_id: 0, loss is: [0.4047526]
epoch: 22, batch_id: 1000, loss is: [0.51240903]
[validation] accuracy/loss: 0.7059704661369324/1.1337810754776
epoch: 23, batch_id: 0, loss is: [0.35482845]
epoch: 23, batch_id: 1000, loss is: [0.43204474]
[validation] accuracy/loss: 0.6940894722938538/1.2031739950180054
epoch: 24, batch_id: 0, loss is: [0.32811758]
epoch: 24, batch_id: 1000, loss is: [0.46759185]
[validation] accuracy/loss: 0.6956868767738342/1.2191218137741089
epoch: 25, batch_id: 0, loss is: [0.24784622]
epoch: 25, batch_id: 1000, loss is: [0.27353552]
[validation] accuracy/loss: 0.6859025359153748/1.3275631666183472
epoch: 26, batch_id: 0, loss is: [0.5800693]
epoch: 26, batch_id: 1000, loss is: [0.37760702]
[validation] accuracy/loss: 0.7023761868476868/1.3081507682800293
epoch: 27, batch_id: 0, loss is: [0.18420647]
epoch: 27, batch_id: 1000, loss is: [0.07969378]
[validation] accuracy/loss: 0.7020766735076904/1.3345941305160522
epoch: 28, batch_id: 0, loss is: [0.76367354]
epoch: 28, batch_id: 1000, loss is: [0.37999988]
[validation] accuracy/loss: 0.7025758624076843/1.3571847677230835
epoch: 29, batch_id: 0, loss is: [0.19946706]
epoch: 29, batch_id: 1000, loss is: [0.1533969]
[validation] accuracy/loss: 0.6838058829307556/1.4618871212005615
epoch: 30, batch_id: 0, loss is: [0.24770181]
epoch: 30, batch_id: 1000, loss is: [0.24702355]
[validation] accuracy/loss: 0.6890974640846252/1.4655603170394897
epoch: 31, batch_id: 0, loss is: [0.2026419]
epoch: 31, batch_id: 1000, loss is: [0.2235108]
[validation] accuracy/loss: 0.6901956796646118/1.525275707244873
epoch: 32, batch_id: 0, loss is: [0.09683963]
epoch: 32, batch_id: 1000, loss is: [0.19227877]
[validation] accuracy/loss: 0.6925918459892273/1.638543725013733
epoch: 33, batch_id: 0, loss is: [0.10310488]
epoch: 33, batch_id: 1000, loss is: [0.44001365]
[validation] accuracy/loss: 0.6937899589538574/1.635685682296753
epoch: 34, batch_id: 0, loss is: [0.07425513]
epoch: 34, batch_id: 1000, loss is: [0.23257709]
[validation] accuracy/loss: 0.693989634513855/1.713061809539795
epoch: 35, batch_id: 0, loss is: [0.1774381]
epoch: 35, batch_id: 1000, loss is: [0.16438515]
[validation] accuracy/loss: 0.6894968152046204/1.75552499294281
epoch: 36, batch_id: 0, loss is: [0.11889877]
epoch: 36, batch_id: 1000, loss is: [0.20611458]
[validation] accuracy/loss: 0.6761181950569153/1.868609070777893
epoch: 37, batch_id: 0, loss is: [0.25529495]
epoch: 37, batch_id: 1000, loss is: [0.03140061]
[validation] accuracy/loss: 0.6914936304092407/1.8362085819244385
epoch: 38, batch_id: 0, loss is: [0.10326625]
epoch: 38, batch_id: 1000, loss is: [0.11940454]
[validation] accuracy/loss: 0.6888977885246277/1.868736982345581
epoch: 39, batch_id: 0, loss is: [0.12311646]
epoch: 39, batch_id: 1000, loss is: [0.1625667]
[validation] accuracy/loss: 0.6854033470153809/1.9806232452392578
epoch: 40, batch_id: 0, loss is: [0.14498284]
epoch: 40, batch_id: 1000, loss is: [0.16378544]
[validation] accuracy/loss: 0.6940894722938538/2.0358874797821045
epoch: 41, batch_id: 0, loss is: [0.31027701]
epoch: 41, batch_id: 1000, loss is: [0.10543761]
[validation] accuracy/loss: 0.6871006488800049/2.0204505920410156
epoch: 42, batch_id: 0, loss is: [0.07480161]
epoch: 42, batch_id: 1000, loss is: [0.27085626]
[validation] accuracy/loss: 0.6767172813415527/2.0724434852600098
epoch: 43, batch_id: 0, loss is: [0.03261568]
epoch: 43, batch_id: 1000, loss is: [0.03589397]
[validation] accuracy/loss: 0.6802116632461548/2.136470317840576
epoch: 44, batch_id: 0, loss is: [0.04260045]
epoch: 44, batch_id: 1000, loss is: [0.3140782]
[validation] accuracy/loss: 0.6779153347015381/2.1353864669799805
epoch: 45, batch_id: 0, loss is: [0.16577259]
epoch: 45, batch_id: 1000, loss is: [0.1480903]
[validation] accuracy/loss: 0.6842052936553955/2.153289794921875
epoch: 46, batch_id: 0, loss is: [0.05079889]
epoch: 46, batch_id: 1000, loss is: [0.179773]
[validation] accuracy/loss: 0.6851038336753845/2.2048838138580322
epoch: 47, batch_id: 0, loss is: [0.05230485]
epoch: 47, batch_id: 1000, loss is: [0.13677706]
[validation] accuracy/loss: 0.6839057803153992/2.2221930027008057
epoch: 48, batch_id: 0, loss is: [0.15611836]
epoch: 48, batch_id: 1000, loss is: [0.20350382]
[validation] accuracy/loss: 0.6836062073707581/2.284668445587158
epoch: 49, batch_id: 0, loss is: [0.18042356]
epoch: 49, batch_id: 1000, loss is: [0.09372107]
[validation] accuracy/loss: 0.6873003244400024/2.4081358909606934

Is there something wrong with how I am using the high-level API? Any help is appreciated, thanks.


ttisahbt1#

Hi! We've received your issue; please be patient while we respond. We will arrange for technicians to answer your question as soon as possible. Please double-check that you have provided a clear problem description, reproduction code, environment & version, and error messages. You may also look for an answer in the official API documentation, the FAQ, historical GitHub Issues, and the AI community. Have a nice day!


iq0todco2#

This appears to be a known issue that has been fixed in the latest release. You can try 2.0.0rc1.
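Besides upgrading, a possible workaround on the rc build is to pass an explicit transform instead of `transform=None`, so the high-level pipeline receives the same normalized NCHW float input that the basic-API loop builds by hand. This is an untested sketch: `to_normalized_chw` is a hypothetical helper, and it assumes the rc dataset yields raw flattened uint8 images when no transform is given.

```python
import numpy as np

def to_normalized_chw(img):
    # Replicate the manual preprocessing from the basic-API loop:
    # float32, CHW layout, values scaled to [0, 1]
    return np.asarray(img, dtype="float32").reshape(3, 32, 32) / 255.0

# Would then be passed to the dataset instead of transform=None, e.g.:
# cifar10_train = paddle.vision.datasets.cifar.Cifar10(
#     mode='train', transform=to_normalized_chw)

# quick self-check on a synthetic all-white record
sample = np.full(3072, 255, dtype=np.uint8)
out = to_normalized_chw(sample)
assert out.shape == (3, 32, 32)
assert float(out.max()) == 1.0
```

With matching preprocessing on both paths, the high-level and basic-API runs should train from equivalent inputs.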
