numpy 将128维数组的迭代器预取到设备

vptzau2j  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(81)

我在使用flax.jax_utils.prefetch_to_device实现下面的简单功能时遇到了问题。我正在加载SIFT 1 M数据集,并将数组转换为jnp数组。
然后我想预取128-dim数组的迭代器。

import tensorflow_datasets as tfds
import tensorflow as tf
import jax
import jax.numpy as jnp
import itertools
import jax.dlpack
import jax.tools.colab_tpu
import flax

def _sift1m_iter():
    def prepare_tf_data(xs):
        def _prepare(x):
            dl_arr = tf.experimental.dlpack.to_dlpack(x)
            jax_arr = jax.dlpack.from_dlpack(dl_arr)
            return jax_arr

        return jax.tree_util.tree_map(_prepare, xs['embedding'])

    ds = tfds.load('sift1m', split='database')
    it = map(prepare_tf_data, ds)
    #it = flax.jax_utils.prefetch_to_device(it, 2)  => this causes an error
    return it

然而,当我运行这段代码时,我得到一个错误:

ValueError: len(shards) = 128 must equal len(devices) = 1.

我在一个只有CPU的设备上运行这个,但是从错误来看,我传递到prefetch_to_device的数据的形状似乎是错误的。

8fq7wneg

8fq7wneg1#

_prepare(x)函数的输出应该具有[num_devices, batch_size]的形状。
在你的例子中,假设你有一个GPU,它的形状应该是[1, 128]
看看如何在这里做到这一点。

相关问题