R keras初学者问题predict_classes

r3i60tvu  于 2022-11-24  发布在  其他
关注(0)|答案(2)|浏览(194)

我以前没有使用Tensorflow或Keras的经验。我正在尝试按照教程https://tensorflow.rstudio.com/tutorials/beginners/

library(keras)

mnist <- dataset_mnist()
mnist$train$x <- mnist$train$x/255
mnist$test$x <- mnist$test$x/255

model <- keras_model_sequential() %>% 
  layer_flatten(input_shape = c(28, 28)) %>% 
  layer_dense(units = 128, activation = "relu") %>% 
  layer_dropout(0.2) %>% 
  layer_dense(10, activation = "softmax")

summary(model)

model %>% 
  compile(
    loss = "sparse_categorical_crossentropy",
    optimizer = "adam",
    metrics = "accuracy"
  )

#Note that compile and fit (which we are going to see next) modify the model object in place, unlike most R functions.

model %>% 
  fit(
    x = mnist$train$x, y = mnist$train$y,
    epochs = 5,
    validation_split = 0.3,
    verbose = 2
  )

predictions <- predict(model, mnist$test$x)
head(predictions, 2)

class_predictions <- predict(model, mnist$test$x) %>% k_argmax()
class_predictions

predict_classes已被弃用。在错误中,k_armax()被作为替代方法发布。但是,我不知道如何将predicted_classes(在本例中为数字0-9)作为向量在confusionMatrix中使用,就像其他R模型一样。如有任何帮助,将不胜感激。

3okqufwl

3okqufwl1#

对于这个问题,下面的代码是有效的

predictions <- predict(model, mnist$test$x)
pred_digits <- apply(predictions, 1, which.max) -1
confusionMatrix(as.factor(pred_digits), as.factor(mnist$test$y))

但是我仍然觉得很奇怪predict_classes被弃用而没有替换。到目前为止我看过的所有教程都使用它。

x3naxklr

x3naxklr2#

predict() %>% k_argmax()返回一个Tensor对象。要复制predict_classes()的结果,您需要将该Tensor对象转换为向量。您可以这样做:

class_predictions <- predict(model, mnist$test$x) %>% k_argmax() %>% as.vector()

此外,this page may be useful

相关问题