我还在学习tensorflow 和Keras,我怀疑这个问题有一个非常简单的答案,我只是错过了由于缺乏熟悉。
我有一个PrefetchDataset
对象:
> print(tf_test)
$ <PrefetchDataset shapes: ((None, 99), (None,)), types: (tf.float32, tf.int64)>
...由特征和目标组成。我可以使用for
循环对其进行迭代:
> for example in tf_test:
> print(example[0].numpy())
> print(example[1].numpy())
> exit()
$ [[-0.31 -0.94 -1.12 ... 0.18 -0.27]
[-0.22 -0.54 -0.14 ... 0.33 -0.55]
[-0.60 -0.02 -1.41 ... 0.21 -0.63]
...
[-0.03 -0.91 -0.12 ... 0.77 -0.23]
[-0.76 -1.48 -0.15 ... 0.38 -0.35]
[-0.55 -0.08 -0.69 ... 0.44 -0.36]]
[0 0 1 0 1 0 0 0 1 0 1 1 0 1 0 0 0
...
0 1 1 0]
然而,这是非常慢的,我想做的是访问对应于类标签的Tensor,并将其转化为一个numpy数组,或列表,或任何类型的可迭代对象,以输入scikit-learn的分类报告和/或混淆矩阵:
> y_pred = model.predict(tf_test)
> print(y_pred)
$ [[0.01]
[0.14]
[0.00]
...
[0.32]
[0.03]
[0.00]]
> y_pred_list = [int(x[0]) for x in y_pred] # assumes value >= 0.5 is positive prediction
> y_true = [] # what I need help with
> print(sklearn.metrics.confusion_matrix(y_true, y_pred_list)
...或访问数据,以便将其用于tensorflow 的混淆矩阵:
> labels = [] # what I need help with
> predictions = y_pred_list # could we just use a tensor?
> print(tf.math.confusion_matrix(labels, predictions)
在这两种情况下,从原始对象中获取目标数据的一般能力在计算上并不昂贵,这将非常有帮助(并且可能有助于我的基本直觉:tensorflow 和Keras)。
如有任何建议,我们将不胜感激。
8条答案
按热度按时间up9lanfz1#
你可以用
list(ds)
将它转换成一个列表,然后用tf.data.Dataset.from_tensor_slices(list(ds))
将它重新编译成一个普通的Dataset。从那里你的噩梦又开始了,但至少这是一个其他人以前经历过的噩梦。请注意,对于更复杂的数据集(例如嵌套字典),在调用
list(ds)
后需要更多的预处理,但这应该适用于您所问的示例。这远不是一个令人满意的答案,但不幸的是,该类完全没有文档记录,标准的Dataset技巧都不起作用。
4ngedf3f2#
可以使用
map
从每个(input, label)
对中选择输入或标签,并将其转换为列表:7d7tgy0s3#
您可以通过循环PrefetchDataset(在我的示例中是train_dataset)来生成列表;
因此,您可以通过使用索引分别访问每个示例和标签;
您也可以使用Pandas将其转换为2列数据框
然后,如果您想将列表转换回PrefetchFataset,您可以简单地使用;
你可以检查它是否能用这个
vqlkdk9b4#
如果要保留批次或将所有标签提取为单个Tensor,可以使用以下函数:
nfs0ujit5#
这是Dataset.prefetch()方法返回的类,它是Dataset的子类。
如果通过将ReadConfig传递给生成器来设置skip_prefetch=Ture,则返回的类型将改为_OptionsDataset。
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#prefetch
bvk5enib6#
我的对象也有类似的问题,如下所示:
我设法从一个批次中提取特征和标签,方法是:
hivapdat7#
可以使用map()函数一次性迭代
ratings.map(lambda x: x["feature name"])
kkbh8khc8#
如果你使用批处理创建了你的
tf.data.Dataset
,并且你想要两个单独的numpy数组,这将把每个列表的列表连接成一个数组。