我在使用flax.jax_utils.prefetch_to_device
实现下面的简单功能时遇到了问题。我正在加载SIFT 1 M数据集,并将数组转换为jnp数组。
然后我想预取128-dim数组的迭代器。
import tensorflow_datasets as tfds
import tensorflow as tf
import jax
import jax.numpy as jnp
import itertools
import jax.dlpack
import jax.tools.colab_tpu
import flax
def _sift1m_iter():
def prepare_tf_data(xs):
def _prepare(x):
dl_arr = tf.experimental.dlpack.to_dlpack(x)
jax_arr = jax.dlpack.from_dlpack(dl_arr)
return jax_arr
return jax.tree_util.tree_map(_prepare, xs['embedding'])
ds = tfds.load('sift1m', split='database')
it = map(prepare_tf_data, ds)
#it = flax.jax_utils.prefetch_to_device(it, 2) => this causes an error
return it
然而,当我运行这段代码时,我得到一个错误:
ValueError: len(shards) = 128 must equal len(devices) = 1.
我在一个只有CPU的设备上运行这个,但是从错误来看,我传递到prefetch_to_device
的数据的形状似乎是错误的。
1条答案
按热度按时间8fq7wneg1#
_prepare(x)
函数的输出应该具有[num_devices, batch_size]
的形状。在你的例子中,假设你有一个GPU,它的形状应该是
[1, 128]
。看看如何在这里做到这一点。