TensorFlow数据集-如何使Map函数将多列作为一个Tensor返回

7hiiyaii  于 2023-03-03  发布在  其他
关注(0)|答案(1)|浏览(152)

问题

如何编写一个Map函数来生成TensorFlow数据集,其中每行都是一个多列Tensor?

问题

这是一个数据集。

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([
    tf.constant([0.14375, 0.0437018, 0.97083336], dtype=np.float32),
    tf.constant([0.14583333, 0.24164525, 0.57916665], dtype=np.float32),
    tf.constant([0.6, 0.5244216, 0.8541667], dtype=np.float32),
])
for d in dataset:
    print(d)
-----
tf.Tensor([0.14375    0.0437018  0.97083336], shape=(3,), dtype=float32)
tf.Tensor([0.14583333 0.24164525 0.57916665], shape=(3,), dtype=float32)
tf.Tensor([0.6       0.5244216 0.8541667], shape=(3,), dtype=float32)

预期结果

应用f并得到预期的数据集,其中每行是(3,)形状的Tensor。这个形状的Tensor是我需要实现的。

def f(x):
    return x * 2

for d in dataset.map(f):
    print(d)
-----
tf.Tensor([0.2875    0.0874036 1.9416667], shape=(3,), dtype=float32)
tf.Tensor([0.29166666 0.4832905  1.1583333 ], shape=(3,), dtype=float32)
tf.Tensor([1.2       1.0488431 1.7083334], shape=(3,), dtype=float32)

意外结果

需要使g返回形状(3,)的Tensor。当前g返回的数据集每行都是形状()中Tensor的元组。

def g(x):
    return x[0], x[1] * 2, x[2] * 3

for d in dataset.map(g):
    print(d)
-----
<tf.Tensor: shape=(), dtype=float32, numpy=0.14375>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0874036>, <tf.Tensor: shape=(), dtype=float32, numpy=2.9125001>)
(<tf.Tensor: shape=(), dtype=float32, numpy=0.14583333>, <tf.Tensor: shape=(), dtype=float32, numpy=0.4832905>, <tf.Tensor: shape=(), dtype=float32, numpy=1.7375>)
(<tf.Tensor: shape=(), dtype=float32, numpy=0.6>, <tf.Tensor: shape=(), dtype=float32, numpy=1.0488431>, <tf.Tensor: shape=(), dtype=float32, numpy=2.5625>)

尝试使g返回(3,)形状的Tensor,但没有成功。
不能使用tf.constant作为返回值。也不能在tf图形/函数内使用numpy。

def g(x):
    return tf.constant([x[0], x[1] * 2, x[2] * 3])

for d in dataset.map(g):
    print(d)
-----
...
TypeError: Expected any non-tensor type, but got a tensor instead.

我如何修复g,使应用的结果产生一个数据集,其中每行是形状为shape=(3,)的单个Tensor,具有所有三个浮点数,如:

tf.Tensor([0.2875    0.0874036 1.9416667], shape=(3,), dtype=float32)

如预期结果所示,当前生成的每个浮点值没有一个Tensor,如:

<tf.Tensor: shape=(), dtype=float32, numpy=0.14375>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0874036>, ...

更新

看起来我可以使用flat_map,但显然它是没有效率的 Package 到额外的名单和解开与flat_map
不确定为什么我不能用map实现同样的效果。

def g(x):
    return tf.data.Dataset.from_tensor_slices([[x[0]*1, x[1]*2, x[2] * 3]])

for d in dataset.flat_map(g):
    print(d)
---
tf.Tensor([0.14375   0.0874036 2.9125001], shape=(3,), dtype=float32)
tf.Tensor([0.14583333 0.4832905  1.7375    ], shape=(3,), dtype=float32)
tf.Tensor([0.6       1.0488431 2.5625   ], shape=(3,), dtype=float32)
kqlmhetl

kqlmhetl1#

看起来flat_map是条路。

def g(x):
    return tf.data.Dataset.from_tensors([x[0]*1, x[1]*2, x[2] * 3])

for d in dataset.flat_map(g):
    print(d)
---
tf.Tensor([0.14375   0.0874036 2.9125001], shape=(3,), dtype=float32)
tf.Tensor([0.14583333 0.4832905  1.7375    ], shape=(3,), dtype=float32)
tf.Tensor([0.6       1.0488431 2.5625   ], shape=(3,), dtype=float32)

相关问题