我正在运行lightweightMMM的一个简单的demo example。我使用lambda函数进行缩放,如下所示:
media_data_train_a = media_data[:split_point, :]
lambda_operation = lambda x: jnp.mean(x[x > 0])
media_scaler = preprocessing.CustomScaler(divide_operation=lambda_operation)
media_data_train = np.array(media_scaler.fit_transform(media_data_train_a))
我已经尝试了相同的代码与示例数据,那里它的工作很好。但是,当我尝试它对我自己的数据集,我得到以下错误时,执行最后一行:NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[41]) See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
你知道我的数据返回此错误的原因吗?
两个数据集都是numpy数组,我也读了建议的jax.doc,认为错误是由于使用了非静态数组,这个问题的常见解决方案是使用jax.numpy.where()
的三参数版本。
如何根据jax.numpy.where()
的JIT兼容的三参数版本实现lambda函数的逻辑?
1条答案
按热度按时间cotxawn71#
问题来自
y = x[x > 0]
返回一个数组,该数组的大小取决于x
中的值;另一种说法是y
具有动态形状,JAX的转换模型目前不支持动态形状的数组,结果就是您看到的错误。通常你可以通过重新表达你的计算来解决这个问题,这样它就不依赖于构造动态形状的中间数组。在这种情况下,类似下面的方法应该可以工作: