Keras Invalid argument: required broadcastable shapes at loc(unknown)

zxlwwiss  posted on 2022-11-30  in Other

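I have been training my model by feeding the fit() method with training and test generators that read data stored in an hdf5 file (about 25,000 images and labels). I recently processed the negative cases into a new hdf5 file containing a similar number of images. However, after updating the generator to read both files, take half a batch of images from each set, and concatenate them, training crashes after a single epoch with Invalid argument: required broadcastable shapes at loc(unknown).
I have made sure that the model output, the generator output, and the data types are all correct (model: UNet, sigmoid, classes=1, output shape = (..., 1), output type = bool), as suggested by other answers to the same question, but I still get the same error.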

train.py

db = h5py.File(db_output_path, 'r')
a = db['data'][200]
b = db['labels'][200]

db_neg = h5py.File(db_negatives_path, 'r')
train_neg_gen = kfold.split(db_neg['data'])
neg_idx = []
for t in train_neg_gen:
    neg_idx.append(t)
batch_size=16

for train, test in kfold.split(db['data'], db['labels']):
    
    train_neg_idx, test_neg_idx = neg_idx[fold_no-1]
    
    gen_train = create_hdf5_generator(db_output_path, train, batch_size, CLASSES, db_negatives_path, train_neg_idx)   
    gen_val = create_hdf5_generator(db_output_path, test, batch_size, CLASSES, db_negatives_path, test_neg_idx)
    model.load_weights('weights/weights_2022-11-20.h5')
    
    # Generate a print
    print('------------------------------------------------------------------------')
    print(f'Training for fold {fold_no} ...')
    
    steps_per_epoch = (2*len(train))//batch_size
    validation_steps= (2*len(test))//batch_size
    
    results = model.fit(gen_train,
                        epochs=10, validation_data=gen_val,
                        steps_per_epoch=steps_per_epoch,
                        validation_steps=validation_steps,
                        callbacks=callbacks)
        
    # Increase fold number
    fold_no = fold_no + 1
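
Because the crash only shows up after thousands of steps, one way to narrow it down is to wrap a generator and compare the leading (batch) dimension of images and labels before Keras sees them; a mismatch there is one possible cause of this error. The following is a minimal sketch that assumes NumPy-array batches. `checked_batches` and `toy_gen` are hypothetical names for illustration, not part of Keras or of the script above:

```python
import numpy as np

def checked_batches(gen):
    """Yield (images, labels) from `gen`, failing fast on a batch-size mismatch.

    Hypothetical debugging helper; not part of Keras.
    """
    for step, (images, labels) in enumerate(gen):
        if images.shape[0] != labels.shape[0]:
            raise ValueError(
                f"step {step}: images {images.shape} vs labels {labels.shape}"
            )
        yield images, labels

def toy_gen():
    """Toy stand-in for create_hdf5_generator: the second batch is inconsistent."""
    yield np.zeros((4, 8, 8, 3)), np.zeros((4, 8, 8, 1))
    yield np.zeros((3, 8, 8, 3)), np.zeros((4, 8, 8, 1))

try:
    for _ in checked_batches(toy_gen()):
        pass
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(error_message)
```

Iterating `checked_batches(gen_val)` for `validation_steps` batches in the same way would report the index of the first inconsistent step without starting a training run.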

Generator

def create_hdf5_generator(db_path, indices, batch_size, classes, neg_db_path=None, neg_indices=None):
    db = h5py.File(db_path)
    neg_db = h5py.File(neg_db_path)
    
    while True:
        if neg_indices is not None:
            skip = batch_size//2
            restart = 0
            for i in np.arange(0, len(indices), skip):
                j = i
                #j tracks neg_db indices which is smaller in size than positive indices tracked by i
                if i >= len(neg_indices):
                    j = restart
                    restart += skip
                    
                images = db['data'][indices[i:i+skip]]
                labels = db['labels'][indices[i:i+skip]]
                
                neg_images = neg_db['data'][neg_indices[j:j+skip]]
                neg_labels = np.zeros(labels.shape).astype(np.float32)
                
                images_concat = np.concatenate((images, neg_images), axis=0)
                labels_concat = np.concatenate((labels, neg_labels), axis=0)
                
                np.random.seed(123)
                np.random.shuffle(images_concat)
                np.random.seed(123)
                np.random.shuffle(labels_concat)
                
                
                yield images_concat, labels_concat.astype(bool)
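
To reason about which batch sizes the generator can produce, the `i`/`j`/`restart` bookkeeping can be replayed on lengths alone, without touching the hdf5 files. This is a sketch under assumptions: `short_negative_steps` is a hypothetical helper, and the sizes are made up for illustration rather than taken from the real dataset:

```python
def short_negative_steps(n_pos, n_neg, skip):
    """Return (step, pos_len, neg_len) for steps of one pass where the lengths differ.

    Mirrors the generator's index logic using slice lengths only.
    """
    mismatches = []
    restart = 0
    for step, i in enumerate(range(0, n_pos, skip)):
        j = i
        if i >= n_neg:
            j = restart
            restart += skip
        # Python and NumPy slicing truncate silently at the end of a sequence.
        pos_len = min(skip, n_pos - i)
        neg_len = max(0, min(skip, n_neg - j))
        if pos_len != neg_len:
            mismatches.append((step, pos_len, neg_len))
    return mismatches

# Illustrative sizes (assumptions): 40 positives, 21 negatives, half-batch of 8.
result = short_negative_steps(n_pos=40, n_neg=21, skip=8)
print(result)  # → [(2, 8, 5)]
```

With these sizes, the third step reads a full half-batch of positives but only five negatives, because the negative index list is not a multiple of `skip`.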

Console output

------------------------------------------------------------------------
Training for fold 1 ...
Epoch 1/10
2773/2774 [============================>.] - ETA: 0s - loss: 0.1157 - mean_io_u_2: 0.4766  Traceback (most recent call last):

  File "C:\Users\Noam\github\proj\train.py", line 181, in <module>
    results = model.fit(gen_train,

  File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1214, in fit
    val_logs = self.evaluate(

  File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1489, in evaluate
    tmp_logs = self.test_function(iterator)

  File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 889, in __call__
    result = self._call(*args, **kwds)

  File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 924, in _call
    results = self._stateful_fn(*args, **kwds)

  File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 3023, in __call__
    return graph_function._call_flat(

  File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 1960, in _call_flat
    return self._build_call_outputs(self._inference_function.call(

  File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\function.py", line 591, in call
    outputs = execute.execute(

  File "C:\Users\Noam\anaconda3\lib\site-packages\tensorflow\python\eager\execute.py", line 59, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  required broadcastable shapes at loc(unknown)
     [[node binary_crossentropy/logistic_loss/mul (defined at C:\Users\Noam\github\proj\train.py:181) ]]
     [[confusion_matrix/assert_non_negative_1/assert_less_equal/Assert/AssertGuard/pivot_f/_12/_33]]
  (1) Invalid argument:  required broadcastable shapes at loc(unknown)
     [[node binary_crossentropy/logistic_loss/mul (defined at C:\Users\Noam\github\proj\train.py:181) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_test_function_79850]

Function call stack:
test_function -> test_function


2022-11-27 19:22:08.581553: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cudart64_110.dll
2022-11-27 19:22:18.055899: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library nvcuda.dll
2022-11-27 19:22:18.073779: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.8GHz coreCount: 82 deviceMemorySize: 24.00GiB deviceMemoryBandwidth: 871.81GiB/s
2022-11-27 19:22:18.073819: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cudart64_110.dll
2022-11-27 19:22:18.093917: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublas64_11.dll
2022-11-27 19:22:18.093939: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublasLt64_11.dll
2022-11-27 19:22:18.100311: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cufft64_10.dll
2022-11-27 19:22:18.102617: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library curand64_10.dll
2022-11-27 19:22:18.105904: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cusolver64_11.dll
2022-11-27 19:22:18.111640: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cusparse64_11.dll
2022-11-27 19:22:18.112034: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cudnn64_8.dll
2022-11-27 19:22:18.112100: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0
2022-11-27 19:22:18.112463: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-27 19:22:18.113094: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1733] Found device 0 with properties: 
pciBusID: 0000:01:00.0 name: NVIDIA GeForce RTX 3090 computeCapability: 8.6
coreClock: 1.8GHz coreCount: 82 deviceMemorySize: 24.00GiB deviceMemoryBandwidth: 871.81GiB/s
2022-11-27 19:22:18.113127: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1871] Adding visible gpu devices: 0
2022-11-27 19:22:18.495306: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1258] Device interconnect StreamExecutor with strength 1 edge matrix:
2022-11-27 19:22:18.495334: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1264]      0 
2022-11-27 19:22:18.495341: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1277] 0:   N 
2022-11-27 19:22:18.495486: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1418] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 21670 MB memory) -> physical GPU (device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:01:00.0, compute capability: 8.6)
2022-11-27 19:22:21.753068: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
2022-11-27 19:22:23.357640: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cudnn64_8.dll
2022-11-27 19:22:23.868767: I tensorflow/stream_executor/cuda/cuda_dnn.cc:359] Loaded cuDNN version 8201
2022-11-27 19:22:24.730172: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublas64_11.dll
2022-11-27 19:22:25.324257: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library cublasLt64_11.dll
2022-11-27 19:23:30.675901: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 19:29:53.026090: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 19:46:47.257803: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 19:50:09.871857: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 19:51:28.339643: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 20:22:00.445508: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 20:30:20.786297: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 20:45:59.779202: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
2022-11-27 21:06:14.203518: W tensorflow/core/framework/op_kernel.cc:1755] Invalid argument: required broadcastable shapes at loc(unknown)
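
The failing node, `binary_crossentropy/logistic_loss/mul`, is an elementwise multiply, so the message indicates that labels and predictions reached the loss with shapes that cannot be broadcast together. TensorFlow follows NumPy-style broadcasting, so the same condition can be illustrated with NumPy alone. The shapes below are illustrative assumptions, not values taken from the log:

```python
import numpy as np

# Assumed shapes: 16 label masks, but predictions for only 12 images.
y_true = np.zeros((16, 128, 128, 1), dtype=np.float32)
y_pred = np.zeros((12, 128, 128, 1), dtype=np.float32)

try:
    y_true * y_pred  # elementwise multiply, analogous to the `mul` node
    broadcast_ok = True
except ValueError as exc:
    broadcast_ok = False
    print("not broadcastable:", exc)

# With matching batch sizes the same multiply goes through.
matched = y_true * np.zeros((16, 128, 128, 1), dtype=np.float32)
print(matched.shape)
```

Since the traceback passes through `evaluate` and `test_function`, the batch in question would have come from the validation generator rather than the training one.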

UNet

sigmoid
binary_crossentropy
Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 128, 128, 3) 0                                            
__________________________________________________________________________________________________
conv2d_57 (Conv2D)              (None, 128, 128, 32) 896         input_4[0][0]                    
__________________________________________________________________________________________________
dropout_27 (Dropout)            (None, 128, 128, 32) 0           conv2d_57[0][0]                  
__________________________________________________________________________________________________
conv2d_58 (Conv2D)              (None, 128, 128, 32) 9248        dropout_27[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, 64, 64, 32)   0           conv2d_58[0][0]                  
__________________________________________________________________________________________________
conv2d_59 (Conv2D)              (None, 64, 64, 64)   18496       max_pooling2d_12[0][0]           
__________________________________________________________________________________________________
dropout_28 (Dropout)            (None, 64, 64, 64)   0           conv2d_59[0][0]                  
__________________________________________________________________________________________________
conv2d_60 (Conv2D)              (None, 64, 64, 64)   36928       dropout_28[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_13 (MaxPooling2D) (None, 32, 32, 64)   0           conv2d_60[0][0]                  
__________________________________________________________________________________________________
conv2d_61 (Conv2D)              (None, 32, 32, 128)  73856       max_pooling2d_13[0][0]           
__________________________________________________________________________________________________
dropout_29 (Dropout)            (None, 32, 32, 128)  0           conv2d_61[0][0]                  
__________________________________________________________________________________________________
conv2d_62 (Conv2D)              (None, 32, 32, 128)  147584      dropout_29[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_14 (MaxPooling2D) (None, 16, 16, 128)  0           conv2d_62[0][0]                  
__________________________________________________________________________________________________
conv2d_63 (Conv2D)              (None, 16, 16, 256)  295168      max_pooling2d_14[0][0]           
__________________________________________________________________________________________________
dropout_30 (Dropout)            (None, 16, 16, 256)  0           conv2d_63[0][0]                  
__________________________________________________________________________________________________
conv2d_64 (Conv2D)              (None, 16, 16, 256)  590080      dropout_30[0][0]                 
__________________________________________________________________________________________________
max_pooling2d_15 (MaxPooling2D) (None, 8, 8, 256)    0           conv2d_64[0][0]                  
__________________________________________________________________________________________________
conv2d_65 (Conv2D)              (None, 8, 8, 512)    1180160     max_pooling2d_15[0][0]           
__________________________________________________________________________________________________
dropout_31 (Dropout)            (None, 8, 8, 512)    0           conv2d_65[0][0]                  
__________________________________________________________________________________________________
conv2d_66 (Conv2D)              (None, 8, 8, 512)    2359808     dropout_31[0][0]                 
__________________________________________________________________________________________________
conv2d_transpose_12 (Conv2DTran (None, 16, 16, 256)  524544      conv2d_66[0][0]                  
__________________________________________________________________________________________________
concatenate_12 (Concatenate)    (None, 16, 16, 512)  0           conv2d_transpose_12[0][0]        
                                                                 conv2d_64[0][0]                  
__________________________________________________________________________________________________
conv2d_67 (Conv2D)              (None, 16, 16, 256)  1179904     concatenate_12[0][0]             
__________________________________________________________________________________________________
dropout_32 (Dropout)            (None, 16, 16, 256)  0           conv2d_67[0][0]                  
__________________________________________________________________________________________________
conv2d_68 (Conv2D)              (None, 16, 16, 256)  590080      dropout_32[0][0]                 
__________________________________________________________________________________________________
conv2d_transpose_13 (Conv2DTran (None, 32, 32, 128)  131200      conv2d_68[0][0]                  
__________________________________________________________________________________________________
concatenate_13 (Concatenate)    (None, 32, 32, 256)  0           conv2d_transpose_13[0][0]        
                                                                 conv2d_62[0][0]                  
__________________________________________________________________________________________________
conv2d_69 (Conv2D)              (None, 32, 32, 128)  295040      concatenate_13[0][0]             
__________________________________________________________________________________________________
dropout_33 (Dropout)            (None, 32, 32, 128)  0           conv2d_69[0][0]                  
__________________________________________________________________________________________________
conv2d_70 (Conv2D)              (None, 32, 32, 128)  147584      dropout_33[0][0]                 
__________________________________________________________________________________________________
conv2d_transpose_14 (Conv2DTran (None, 64, 64, 64)   32832       conv2d_70[0][0]                  
__________________________________________________________________________________________________
concatenate_14 (Concatenate)    (None, 64, 64, 128)  0           conv2d_transpose_14[0][0]        
                                                                 conv2d_60[0][0]                  
__________________________________________________________________________________________________
conv2d_71 (Conv2D)              (None, 64, 64, 64)   73792       concatenate_14[0][0]             
__________________________________________________________________________________________________
dropout_34 (Dropout)            (None, 64, 64, 64)   0           conv2d_71[0][0]                  
__________________________________________________________________________________________________
conv2d_72 (Conv2D)              (None, 64, 64, 64)   36928       dropout_34[0][0]                 
__________________________________________________________________________________________________
conv2d_transpose_15 (Conv2DTran (None, 128, 128, 32) 8224        conv2d_72[0][0]                  
__________________________________________________________________________________________________
concatenate_15 (Concatenate)    (None, 128, 128, 64) 0           conv2d_transpose_15[0][0]        
                                                                 conv2d_58[0][0]                  
__________________________________________________________________________________________________
conv2d_73 (Conv2D)              (None, 128, 128, 32) 18464       concatenate_15[0][0]             
__________________________________________________________________________________________________
dropout_35 (Dropout)            (None, 128, 128, 32) 0           conv2d_73[0][0]                  
__________________________________________________________________________________________________
conv2d_74 (Conv2D)              (None, 128, 128, 32) 9248        dropout_35[0][0]                 
__________________________________________________________________________________________________
conv2d_75 (Conv2D)              (None, 128, 128, 1)  33          conv2d_74[0][0]                  
==================================================================================================
Total params: 7,760,097
Trainable params: 7,760,097
Non-trainable params: 0

bf1o4zei #1

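After some debugging, the error turned out to be in one of the generator's output shapes. I was always giving neg_labels the same shape as labels, even though neg_images might not match labels along the zeroth axis.
The fix is to give neg_labels the shape of neg_images on the first three axes and the shape of labels on the last axis: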

neg_images = neg_db['data'][neg_indices[j:j+skip]]
neg_labels = np.zeros((*neg_images.shape[:3], labels.shape[3]), dtype=np.float32)
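
The effect of this change can be reproduced with small in-memory arrays. The sketch below is an illustration under assumptions: `neg_store` is a hypothetical stand-in for `neg_db['data']`, the spatial size is shrunk to 4x4, and the counts (21 negatives, half-batch of 8) are made up:

```python
import numpy as np

skip = 8  # batch_size // 2, as in the generator
images = np.zeros((skip, 4, 4, 3), dtype=np.float32)  # full positive half-batch
labels = np.ones((skip, 4, 4, 1), dtype=np.float32)

# Assumed: 21 negatives in total, so the slice starting at j = 16 holds only 5.
neg_store = np.zeros((21, 4, 4, 3), dtype=np.float32)  # stand-in for neg_db['data']
neg_indices = np.arange(21)
j = 16
neg_images = neg_store[neg_indices[j:j + skip]]  # slicing truncates silently

# Original approach: the negative label count follows the positives.
bad_neg_labels = np.zeros(labels.shape, dtype=np.float32)
# Approach from the answer: the count follows the negatives actually read.
good_neg_labels = np.zeros((*neg_images.shape[:3], labels.shape[3]), dtype=np.float32)

images_concat = np.concatenate((images, neg_images), axis=0)
bad_labels_concat = np.concatenate((labels, bad_neg_labels), axis=0)
good_labels_concat = np.concatenate((labels, good_neg_labels), axis=0)

print(images_concat.shape[0], bad_labels_concat.shape[0], good_labels_concat.shape[0])
```

With the original label shape the batch carries 13 images and 16 label masks; with the corrected shape both have 13 entries, so the loss receives matching batch dimensions.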
