在pytorch中用我自己的图像测试我的mnist模型

o8x7eapl  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(115)

我按照geeks for geeks中的教程操作
我在测试时获得了很好的准确性

Epoch [1/10],Loss:0.1767,Validation Loss:15.1023,Accuracy:0.95,Validation Accuracy:0.95
Epoch [2/10],Loss:0.1112,Validation Loss:11.1439,Accuracy:0.97,Validation Accuracy:0.96
Epoch [3/10],Loss:0.0711,Validation Loss:8.6009,Accuracy:0.96,Validation Accuracy:0.98
Epoch [4/10],Loss:0.0606,Validation Loss:7.4755,Accuracy:0.97,Validation Accuracy:0.98
Epoch [5/10],Loss:0.0393,Validation Loss:6.7248,Accuracy:0.99,Validation Accuracy:0.99
Epoch [6/10],Loss:0.0311,Validation Loss:8.4266,Accuracy:1.00,Validation Accuracy:0.99
Epoch [7/10],Loss:0.0388,Validation Loss:6.4547,Accuracy:1.00,Validation Accuracy:0.99
Epoch [8/10],Loss:0.0216,Validation Loss:6.4336,Accuracy:1.00,Validation Accuracy:1.00
Epoch [9/10],Loss:0.0426,Validation Loss:6.8441,Accuracy:1.00,Validation Accuracy:0.98
Epoch [10/10],Loss:0.0167,Validation Loss:6.0449,Accuracy:0.99,Validation Accuracy:1.00

字符串
测试准确度:

Test Accuracy: 99.12%
              precision    recall  f1-score   support

           0       0.99      1.00      0.99       980
           1       0.99      1.00      0.99      1135
           2       0.99      0.99      0.99      1032
           3       0.99      0.99      0.99      1010
           4       0.99      0.99      0.99       982
           5       0.99      0.99      0.99       892
           6       1.00      0.99      0.99       958
           7       0.99      0.99      0.99      1028
           8       0.99      0.99      0.99       974
           9       0.99      0.98      0.99      1009

    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000


但当我用自己的图像测试这个模型时,

import cv2
import numpy as np
from PIL import *
image = cv2.imread(r"C:\Users\Sanmitha\Documents\first.jpg",0)
image = cv2.resize(image, (28,28))
batch = torch.tensor(image / 255).unsqueeze(0).float()
with torch.no_grad():
        batch = batch.to(device)
        output = model( batch )
        output = torch.argmax(output, 1)
        print(output)


输出为:

tensor([5])


图像为second.jpg
我还有一张图片要测试:first.jpg
在这两个图像中,我需要检测门号。答案应该是
在second.jpg中:110/50(至少只有数字)。在first.jpg中:38/42(至少只有数字)
我不知道怎么做这份工作。好心帮忙

4bbkushb

4bbkushb1#

从您的测试图像中可以明显看出,您并没有试图解决类似MNIST的问题,而是解决类似SVHN的问题。你需要使用一种不同的模型。你训练的那个只能在图像中只有一个数字时分类数字;它不能检测多个数字。因此,您需要像YOLO这样的对象检测模型。YOLO将为图像中的每个数字提供边界框,允许您按坐标对其进行排序并获得正确的顺序。有很多YOLO模型版本,对于其中一些,可以找到预训练的权重。我建议你尝试使用这个代码与预训练模型。

相关问题