keras 回调tensorflow 中不工作,停止训练

cidc1ykv  于 2023-04-06  发布在  其他
关注(0)|答案(2)|浏览(154)

我写了一个回调,当准确率达到99%时停止训练。但问题是我得到了这个错误。有时,如果我解决了这个错误,即使acuurqacy成为100%,回调也不会被调用。
在“NoneType”和“float”的示例之间不支持“〉”

class myCallback(tf.keras.callbacks.Callback):
        
        def on_epoch_end(self, epoch, logs={}):
            
            if(logs.get('accuracy') > 0.99):
                
                
               
               self.model.stop_training = True

def train_mnist():
    # Please write your code only where you are indicated.
    # please do not remove # model fitting inline comments.

    # YOUR CODE SHOULD START HERE

    # YOUR CODE SHOULD END HERE
    call = myCallback()
    mnist = tf.keras.datasets.mnist

    (x_train, y_train),(x_test, y_test) = mnist.load_data(path=path)
    # YOUR CODE SHOULD START
    x_train = x_train/255
    y_train = y_train/255
    # YOUR CODE SHOULD END HERE
    model = tf.keras.models.Sequential([
        # YOUR CODE SHOULD START HERE
          keras.layers.Flatten(input_shape=(28,28)),
          keras.layers.Dense(128,activation='relu'),
          keras.layers.Dense(10,activation='softmax')
        # YOUR CODE SHOULD END HERE
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    # model fitting
    history = model.fit(# YOUR CODE SHOULD START HERE
          x_train,y_train,epochs=9,callbacks=[call] )
    # model fitting
    return history.epoch, history.history['acc'][-1]
jvlzgdj9

jvlzgdj91#

上面代码的两个主要问题:

  • 在训练集上达到100%的准确率几乎总是意味着你的模型是过拟合的。这是BAD。你要做的是在.fit方法中指定validation_split=.2参数,并在验证集上寻找高准确率。
  • 您尝试在自定义回调中构建的内容已经在keras.callbacks.EarlyStopping中完成,它甚至可以选择在每个epoch中恢复到最佳整体模型。并且,默认情况下,如果您有验证分割,它会寻找验证精度,而不是训练精度。

所以,你应该这样做:停止使用自定义回调,它们需要一些掌握才能开始工作。使用EarlyStoppingrestore_best代替。like this始终使用validation_split并在验证集中寻找高准确性。Like in this quick example
使用内置的回调函数解决了你的问题吗?

q43xntqr

q43xntqr2#

我有同样的问题,但不是logs.get('accuracy'),我做了logs.get('acc'),它的工作。

相关问题