我有一个函数,基本上会示例化一个巨大的数组,并做其他事情。我在TPU上运行我的代码,所以基本上我的内存是有限的。
如何在CPU上执行特定功能?
如果我这样做:
y = jax.device_put(my_function(), device=jax.devices("cpu")[0])
我猜my_function()
首先在TPU上执行,结果放在CPU上,这给了我内存错误。
在代码开头使用jax.config.update('jax_platform_name', 'cpu')
似乎没有效果。
另外请注意,我不能修改my_function()
谢谢你!
2条答案
按热度按时间4xrmg8kj1#
我在这里猜一猜。我也不能运行它,所以你可能要摆弄它
看这里和这里的文件。
eagi6jfj2#
要直接指定执行函数的设备,请使用
jax.jit
的device
参数。例如(使用GPU运行时,因为它是我目前可以访问的加速器):这也可以通过调用点周围的
jax.default_device
装饰器来控制: