python - TFLite inference predicts only one label despite multi-class training

eeq64g8w posted on 2022-12-21 in Python

I trained a multi-class classifier for speech recognition with TensorFlow, then converted the model with the TFLite converter. The converted model makes predictions, but it always outputs a single class. I suspect the problem is in the inference code, because the .h5 model predicts multiple classes without any issue. I have been searching online for days, but I can't quite figure it out. Any suggestion would be appreciated.
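One quick way to confirm that suspicion is to feed the exact same feature array to the Keras model and to the converted interpreter and compare the raw output vectors. A minimal sanity-check sketch, assuming the original model is saved as model.h5 and using a random array shaped like the model input (the file names and the shape are placeholders for your own values):

import numpy as np
import tensorflow as tf

# Load both versions of the model (placeholder paths)
keras_model = tf.keras.models.load_model('model.h5')
interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Stand-in for one preprocessed MFCC window; shape it exactly
# like your model input, e.g. (1, 32, 50, 1)
features = np.random.rand(1, 32, 50, 1).astype(np.float32)

keras_out = keras_model.predict(features)
interpreter.set_tensor(input_details[0]['index'], features)
interpreter.invoke()
tflite_out = interpreter.get_tensor(output_details[0]['index'])

# If these differ substantially, the conversion is at fault;
# if they match, look at preprocessing or the dataset instead.
print('keras :', keras_out)
print('tflite:', tflite_out)

My full code is below.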

import sounddevice as sd
import numpy as np
import scipy.signal
import timeit
import python_speech_features

import tflite_runtime.interpreter as tflite

# Parameters
debug_time = 0
debug_acc = 0
word_threshold = 0.95
rec_duration = 0.5   # 0.5
sample_length = 0.5
window_stride = 0.5  # 0.5
sample_rate = 8000   # The mic requires at least 44100 Hz to work
resample_rate = 8000
num_channels = 1
num_mfcc = 16

model_path = 'model.tflite'

mfccs_old = np.zeros((32, 25))

# Load model (interpreter)
interpreter = tflite.Interpreter(model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details)

# Filter and downsample
def decimate(signal, old_fs, new_fs):

    # Check to make sure we're downsampling
    if new_fs > old_fs:
        print("Error: target sample rate higher than original")
        return signal, old_fs

    # Downsampling is possible only by an integer factor
    dec_factor = old_fs / new_fs
    if not dec_factor.is_integer():
        print("Error: can only downsample by integer factor")
        return signal, old_fs
    # Do decimation
    resampled_signal = scipy.signal.decimate(signal, int(dec_factor))

    return resampled_signal, new_fs

# Callback that gets called every 0.5 seconds
def sd_callback(rec, frames, time, status):

    # Start timing for debug purposes
    start = timeit.default_timer()

    # Notify errors
    if status:
        print('Error:', status)

    global mfccs_old

    # Compute MFCCs
    mfccs = python_speech_features.base.mfcc(rec,
                                            samplerate=resample_rate,
                                            winlen=0.02,
                                            winstep=0.02,
                                            numcep=num_mfcc,
                                            nfilt=26,
                                            nfft=512, # 2048
                                            preemph=0.0,
                                            ceplifter=0,
                                            appendEnergy=True,
                                            winfunc=np.hanning)

    delta = python_speech_features.base.delta(mfccs, 2)

    mfccs_delta = np.append(mfccs, delta, axis=1)

    # Stitch the previous 0.5 s of MFCCs onto the new window so the model
    # always sees a full 1 s of features (rows: coefficients, cols: frames)
    mfccs_new = mfccs_delta.transpose()
    mfccs = np.append(mfccs_old, mfccs_new, axis=1)
#    mfccs = np.insert(mfccs, [0], 0, axis=1)
    mfccs_old = mfccs_new

    # Run inference and make predictions
    in_tensor = np.float32(mfccs.reshape(1, mfccs.shape[0], mfccs.shape[1], 1))
    interpreter.set_tensor(input_details[0]['index'], in_tensor)
    interpreter.invoke()
    output_data = interpreter.get_tensor(output_details[0]['index'])
    val = np.amax(output_data)                      # DEFINED FOR BINARY CLASSIFICATION, CHANGE TO MULTICLASS
    ind = np.where(output_data == val)
    prediction = ind[1].astype(int)
    if val > word_threshold:
        print('index:', ind[1])
        print('accuracy', val, '\n')
        print(int(prediction))

    if debug_acc:
#        print('accuracy:', val)
#        print('index:', ind[1])
        print('out tensor:', output_data)
    if debug_time:
        print(timeit.default_timer() - start)

# Start recording from microphone
with sd.InputStream(channels=num_channels,
        samplerate=sample_rate,
        blocksize=int(sample_rate * rec_duration),
        callback=sd_callback):
    while True:
        pass
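As the inline comment in the callback notes, the amax/where pair was written with binary classification in mind; np.argmax generalizes cleanly to any number of classes. A minimal sketch of the replacement, assuming output_data has shape (1, num_classes):

# Pick the most likely class directly from the output tensor
scores = output_data[0]              # shape: (num_classes,)
prediction = int(np.argmax(scores))  # index of the winning class
val = scores[prediction]             # its score
if val > word_threshold:
    print('index:', prediction, 'score:', val)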
z2acfund 1#

Since I figured out the problem, I am answering it myself in case someone else finds it useful.
The problem was that my dataset had no "background noise" class. Also make sure you have enough background noise data. If you look at Google's Teachable Machine audio projects (https://teachablemachine.withgoogle.com/train/audio), a "Background Noise" class is already there and you cannot delete or disable it.
I tested with the two examples provided in TensorFlow's GitHub repository (https://github.com/tensorflow/examples/blob/master/lite/examples/sound_classification/raspberry_pi/classify.py) and on the TensorFlow website (https://www.tensorflow.org/tutorials/audio/simple_audio); both predict well, provided the dataset contains enough background noise samples for the specific environment the test runs in.
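If your own dataset lacks such a class, one simple way to build it is to record a few minutes of the actual test environment and slice the recording into training-sized clips. A rough sketch, assuming a long mono recording noise.wav at the training sample rate (all paths are placeholders):

import os
from scipy.io import wavfile

# Slice a long noise recording into 1-second clips for a
# "background noise" class (paths are placeholders)
os.makedirs('background_noise', exist_ok=True)
rate, data = wavfile.read('noise.wav')
clip_len = rate  # samples per 1-second clip
for i in range(len(data) // clip_len):
    clip = data[i * clip_len:(i + 1) * clip_len]
    wavfile.write('background_noise/clip_%03d.wav' % i, rate, clip)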
I made some minor modifications to the TensorFlow GitHub code so that it prints the category name and the category score.

# Loop until the user closes the classification results.
while True:
  # Wait until at least interval_between_inference seconds have passed
  # since the last inference.
  now = time.time()
  diff = now - last_inference_time
  if diff < interval_between_inference:
    time.sleep(pause_time)
    continue
  last_inference_time = now

  # Load the input audio and run classification
  tensor_audio.load_from_audio_record(audio_record)
  result = classifier.classify(tensor_audio)
  for category in result.classifications[0].categories:
    print(category.category_name, category.score)
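For reference, the undefined names in this snippet (classifier, tensor_audio, audio_record and the timing variables) come from the setup code earlier in the linked classify.py, which builds everything with the TFLite Support Task library. Roughly, with a placeholder model path:

import time
from tflite_support.task import audio, core, processor

# Build the audio classifier from the .tflite model (placeholder path)
base_options = core.BaseOptions(file_name='model.tflite')
classification_options = processor.ClassificationOptions(max_results=3)
options = audio.AudioClassifierOptions(
    base_options=base_options,
    classification_options=classification_options)
classifier = audio.AudioClassifier.create_from_options(options)

# Recorder and input tensor sized to the model's requirements
audio_record = classifier.create_audio_record()
tensor_audio = classifier.create_input_tensor_audio()
audio_record.start_recording()

interval_between_inference = 0.1  # seconds; tune to your model
pause_time = interval_between_inference * 0.1
last_inference_time = time.time()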

I hope this helps anyone working on similar projects.
