我有视频字幕项目的数据集。训练的数据集管道构建为:
dataset = tf.data.Dataset.from_tensor_slices((videos , tf.ragged.constant(captions)))
我想读取进入培训步骤的所有batch_data,如下所示:
class VideoCaptioningModel(keras.Model):
.
.
.
def train_step(self, batch_data):
batch_img, batch_seq = batch_data
batch_loss = 0
batch_acc = 0
print('batch_data=', batch_data)
.
.
输出为:
batch_data= (<tf.Tensor 'IteratorGetNext:0' shape=(None, 28, 1536) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None, None, 8) dtype=int64>)
我尝试使用print('batch_data=', batch_data.numpy())
,但得到:
AttributeError: 'tuple' object has no attribute 'numpy'
1条答案
按热度按时间b4lqfgs41#
数据集由
videos
和captions
组成,数据集中的每个条目都是tuple
。请参阅:现在,注意你可以在
Eager Execution
模式下在tf.Tensor
上调用.numpy()
,但是元组没有这个属性。