修复了Python中jax.numpy.where()的非具体布尔索引错误

nbnkbykc  于 2022-12-17  发布在  Python
关注(0)|答案(1)|浏览(172)

我正在运行lightweightMMM的一个简单的demo example。我使用lambda函数进行缩放,如下所示:

media_data_train_a = media_data[:split_point, :]
lambda_operation = lambda x: jnp.mean(x[x > 0])
media_scaler = preprocessing.CustomScaler(divide_operation=lambda_operation)
media_data_train = np.array(media_scaler.fit_transform(media_data_train_a))

我已经尝试了相同的代码与示例数据,那里它的工作很好。但是,当我尝试它对我自己的数据集,我得到以下错误时,执行最后一行:
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[41]) See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
你知道我的数据返回此错误的原因吗?
两个数据集都是numpy数组,我也读了建议的jax.doc,认为错误是由于使用了非静态数组,这个问题的常见解决方案是使用jax.numpy.where()的三参数版本。
如何根据jax.numpy.where()的JIT兼容的三参数版本实现lambda函数的逻辑?

cotxawn7

cotxawn71#

问题来自y = x[x > 0]返回一个数组,该数组的大小取决于x中的值;另一种说法是y具有动态形状,JAX的转换模型目前不支持动态形状的数组,结果就是您看到的错误。
通常你可以通过重新表达你的计算来解决这个问题,这样它就不依赖于构造动态形状的中间数组。在这种情况下,类似下面的方法应该可以工作:

lambda_operation = lambda x: jnp.where(x > 0, x, 0.0).sum() / (x > 0).sum()

相关问题