TensorFlow: error when fitting a model in Colab, but it runs fine in a Jupyter notebook

Posted by tvokkenx on 2022-11-16 in Other

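I am trying to fit a model in Colab using my training and test data (about 7 GB), because training takes too long in a Jupyter notebook on my local machine. However, when I run it in Colab I get the error below, even though the same code works fine in Jupyter.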

Epoch 1/20
---------------------------------------------------------------------------
UnimplementedError                        Traceback (most recent call last)
<ipython-input-60-677f43e317d7> in <module>()
      6   epochs=20,
      7   steps_per_epoch=len(training_set),
----> 8   validation_steps=len(testing_set)
      9 )

1 frames
/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     53     ctx.ensure_initialized()
     54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:
     57     if name is not None:

UnimplementedError: Graph execution error:

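I have already changed the runtime type to GPU in Colab. I am not currently using Colab Pro, and the dataset is stored in Google Drive. I am a bit confused, because the code runs in a Jupyter notebook without any problems.
The Colab notebook is available on GitHub: https://github.com/ArchieVon/DL/blob/main/ResNet_Test1.ipynb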

Answer #1, by cyvaqqii

Changing IMAGE_SIZE from a list to a tuple resolves the issue. Please find the working code below.
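As a minimal illustration of that change, separate from the full listing: the sketch below normalizes an image size given as a list into the tuple form passed to `input_shape`. The helper name `as_input_shape` is hypothetical, and the list form `[224, 224]` is an assumption about the original notebook, which is not reproduced in this thread.

```python
def as_input_shape(image_size, channels=3):
    """Return the image size as a (height, width, channels) tuple of ints.

    Hypothetical helper, not part of Keras. Accepts a list or tuple with
    either 2 entries (height, width) or 3 entries (height, width, channels).
    """
    dims = tuple(int(d) for d in image_size)
    if len(dims) == 2:
        # Append the channel count when only height and width are given.
        dims = dims + (channels,)
    if len(dims) != 3:
        raise ValueError(f"expected 2 or 3 dimensions, got {len(dims)}")
    return dims


# Assumed original form: a list. The result is the tuple used in the answer.
IMAGE_SIZE = as_input_shape([224, 224])
print(IMAGE_SIZE)  # (224, 224, 3)
```

The resulting tuple can then be passed as `input_shape=IMAGE_SIZE`, as in the listing that follows.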

from tensorflow.keras.layers import Input, Lambda, Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator,load_img
from tensorflow.keras.models import Sequential
import numpy as np
from glob import glob

IMAGE_SIZE = (224, 224, 3)

train_path = '/content/dogs_vs_cats/train'
valid_path = '/content/dogs_vs_cats/test'

resnet = ResNet50(input_shape=IMAGE_SIZE, weights='imagenet', include_top=False)
for layer in resnet.layers:
    layer.trainable = False

folders = glob('/content/dogs_vs_cats/train/*')
len(folders) 

x = Flatten()(resnet.output)
prediction = Dense(len(folders), activation='softmax')(x)

model = Model(inputs=resnet.input, outputs=prediction)
#model.summary()

model.compile(
  loss='categorical_crossentropy',
  optimizer='adam',
  metrics=['accuracy']
)

# Data generators: augment the training images, only rescale the test images.
# (ImageDataGenerator is already imported at the top.)

train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)

test_datagen = ImageDataGenerator(rescale = 1./255)

training_set = train_datagen.flow_from_directory(train_path,
                                                 target_size = (224, 224),
                                                 batch_size = 32,
                                                 class_mode = 'categorical')
testing_set = test_datagen.flow_from_directory(valid_path,
                                            target_size = (224, 224),
                                            batch_size = 32,
                                            class_mode = 'categorical')

len(testing_set)
len(training_set)

# fit the model
# Run the cell. It will take some time to execute
r = model.fit(
  training_set,
  validation_data=testing_set,
  epochs=1,
  steps_per_epoch=len(training_set),
  validation_steps=len(testing_set)
)

The output is as follows:

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
94773248/94765736 [==============================] - 0s 0us/step
94781440/94765736 [==============================] - 0s 0us/step
Found 20000 images belonging to 2 classes.
Found 5000 images belonging to 2 classes.
625/625 [==============================] - ETA: 0s - loss: 0.9511 - accuracy: 0.6004
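The `625/625` in the progress bar follows from the dataset size and batch size: `len(training_set)` is the number of batches per epoch, which is the image count divided by the batch size, rounded up. A small sketch of that arithmetic, using the image counts from the output above (the helper name `steps_for` is hypothetical):

```python
import math


def steps_for(num_images, batch_size):
    """Number of batches needed to cover num_images once (last batch may be partial)."""
    return math.ceil(num_images / batch_size)


train_steps = steps_for(20000, 32)  # training images reported in the output
val_steps = steps_for(5000, 32)     # test images reported in the output

print(train_steps)  # 625
print(val_steps)    # 157
```

Because the generators already report their own length, passing `steps_per_epoch=len(training_set)` and `validation_steps=len(testing_set)` simply makes each epoch cover every batch once.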

If the problem persists, please let us know. Thanks!
