问题
如何编写一个Map函数来生成TensorFlow数据集,其中每行都是一个多列Tensor?
问题
这是一个数据集。
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([
tf.constant([0.14375, 0.0437018, 0.97083336], dtype=np.float32),
tf.constant([0.14583333, 0.24164525, 0.57916665], dtype=np.float32),
tf.constant([0.6, 0.5244216, 0.8541667], dtype=np.float32),
])
for d in dataset:
print(d)
-----
tf.Tensor([0.14375 0.0437018 0.97083336], shape=(3,), dtype=float32)
tf.Tensor([0.14583333 0.24164525 0.57916665], shape=(3,), dtype=float32)
tf.Tensor([0.6 0.5244216 0.8541667], shape=(3,), dtype=float32)
预期结果
应用f
并得到预期的数据集,其中每行是(3,)
形状的Tensor。这个形状的Tensor是我需要实现的。
def f(x):
return x * 2
for d in dataset.map(f):
print(d)
-----
tf.Tensor([0.2875 0.0874036 1.9416667], shape=(3,), dtype=float32)
tf.Tensor([0.29166666 0.4832905 1.1583333 ], shape=(3,), dtype=float32)
tf.Tensor([1.2 1.0488431 1.7083334], shape=(3,), dtype=float32)
意外结果
需要使g
返回形状(3,)
的Tensor。当前g
返回的数据集每行都是形状()
中Tensor的元组。
def g(x):
return x[0], x[1] * 2, x[2] * 3
for d in dataset.map(g):
print(d)
-----
<tf.Tensor: shape=(), dtype=float32, numpy=0.14375>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0874036>, <tf.Tensor: shape=(), dtype=float32, numpy=2.9125001>)
(<tf.Tensor: shape=(), dtype=float32, numpy=0.14583333>, <tf.Tensor: shape=(), dtype=float32, numpy=0.4832905>, <tf.Tensor: shape=(), dtype=float32, numpy=1.7375>)
(<tf.Tensor: shape=(), dtype=float32, numpy=0.6>, <tf.Tensor: shape=(), dtype=float32, numpy=1.0488431>, <tf.Tensor: shape=(), dtype=float32, numpy=2.5625>)
尝试使g
返回(3,)
形状的Tensor,但没有成功。
不能使用tf.constant
作为返回值。也不能在tf图形/函数内使用numpy。
def g(x):
return tf.constant([x[0], x[1] * 2, x[2] * 3])
for d in dataset.map(g):
print(d)
-----
...
TypeError: Expected any non-tensor type, but got a tensor instead.
我如何修复g
,使应用的结果产生一个数据集,其中每行是形状为shape=(3,)
的单个Tensor,具有所有三个浮点数,如:
tf.Tensor([0.2875 0.0874036 1.9416667], shape=(3,), dtype=float32)
如预期结果所示,当前生成的每个浮点值没有一个Tensor,如:
<tf.Tensor: shape=(), dtype=float32, numpy=0.14375>, <tf.Tensor: shape=(), dtype=float32, numpy=0.0874036>, ...
更新
看起来我可以使用flat_map
,但显然它是没有效率的 Package 到额外的名单和解开与flat_map
。
不确定为什么我不能用map
实现同样的效果。
def g(x):
return tf.data.Dataset.from_tensor_slices([[x[0]*1, x[1]*2, x[2] * 3]])
for d in dataset.flat_map(g):
print(d)
---
tf.Tensor([0.14375 0.0874036 2.9125001], shape=(3,), dtype=float32)
tf.Tensor([0.14583333 0.4832905 1.7375 ], shape=(3,), dtype=float32)
tf.Tensor([0.6 1.0488431 2.5625 ], shape=(3,), dtype=float32)
1条答案
按热度按时间kqlmhetl1#
看起来
flat_map
是条路。