Flutter TFLite模型保持输出相同的结果

x7rlezfr  于 2023-05-23  发布在  Flutter
关注(0)|答案(1)|浏览(161)

我正在使用tensorflow和python构建CNN分类模型。该模型的输入形状为[1,50,7],由时间戳的第一列和其余列的传感器值组成。输出值为0或1以指定向左或向右的运动。然后,我将模型导出为TFLite模型,并使用tflite_flutter包(https://pub.dev/packages/tflite_flutter)在Flutter中使用它。
当我使用解释器run运行时,数据的输出总是0.0。但是,当我使用python运行时,我注意到在阅读一个csv数据后,需要添加

input_data = input_data.astype('float32')

正确运行模型,它输出一个0到1的值,这是我想要的,否则它将输出由于得到FLOAT64而不是FLOAT32而无法得到Tensor。因此,我尝试使用Flutter中的Float32List将数据转换为float32,但结果仍然是0.0。

List<Float32List> group32Float = [];
    for (var i = 0; i < 50; i++) {
       group32Float.add(Float32List.fromList(group[i]));
    }
    interpreter!.run([group32Float], [output]);

我的模型是这样的:

input_shape = (50, 7)

    model = Sequential()
    model.add(Conv1D(filters=32, kernel_size=3, activation='relu', padding='same', input_shape=input_shape))
    model.add(BatchNormalization())
    model.add(MaxPooling1D(pool_size=2))
    model.add(Dropout(0.25))
    model.add(Conv1D(filters=64, kernel_size=3, activation='relu', padding='same'))
    model.add(BatchNormalization())
    model.add(MaxPooling1D(pool_size=2))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(units=64, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(units=32, activation='relu', kernel_regularizer=regularizers.l2(0.001)))
    model.add(BatchNormalization())
    model.add(Dropout(0.5))
    model.add(Dense(1, activation='sigmoid'))

    optimizer = Adam(learning_rate=0.001)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    early_stop = EarlyStopping(monitor='val_loss', patience=100)

    model.fit(X_train, y_train, epochs=1000, validation_data=(X_val, y_val), callbacks=[early_stop])

然后保存为TFLite:

model.save('model', save_format='tf')

    converter = tf.lite.TFLiteConverter.from_saved_model('model')
    tflite_model = converter.convert()

    with open('model.tflite', 'wb') as f:
        f.write(tflite_model)

我的问题是:为什么我在Flutter中的输出总是0.0?

flvlnr44

flvlnr441#

对于任何有同样问题的人,我已经找到了解决方案。我意识到我最初没有将输入数据设置为正确的类型,在我的情况下是:

List<List<double>> input

尝试检查输入和输出的变量类型,并确保它是发送到模型中的正确类型。

相关问题