tensorflow 如何读取tf.tensor中的数据

kcugc4gi  于 2022-11-16  发布在  其他
关注(0)|答案(1)|浏览(193)

我有视频字幕项目的数据集。训练的数据集管道构建为:

dataset = tf.data.Dataset.from_tensor_slices((videos , tf.ragged.constant(captions)))

我想读取进入培训步骤的所有batch_data,如下所示:

class VideoCaptioningModel(keras.Model):
.
.
.
    def train_step(self, batch_data):
        batch_img, batch_seq = batch_data
        batch_loss = 0
        batch_acc = 0
        
        print('batch_data=', batch_data)

.
.

输出为:

batch_data= (<tf.Tensor 'IteratorGetNext:0' shape=(None, 28, 1536) dtype=float32>, <tf.Tensor 'IteratorGetNext:1' shape=(None, None, 8) dtype=int64>)

我尝试使用print('batch_data=', batch_data.numpy()),但得到:

AttributeError: 'tuple' object has no attribute 'numpy'
b4lqfgs4

b4lqfgs41#

数据集由videoscaptions组成,数据集中的每个条目都是tuple。请参阅:

for x in dataset:
  tf.print(x[0]) # videos
  tf.print(x[1]) # captions

现在,注意你可以在Eager Execution模式下在tf.Tensor上调用.numpy(),但是元组没有这个属性。

tf.print('batch_data=', batch_data[0].numpy())
tf.print('batch_data=', batch_data[1].numpy())

相关问题