python 在Jax中专门在CPU上执行函数

yk9xbfzb  于 2022-11-27  发布在  Python
关注(0)|答案(2)|浏览(140)

我有一个函数,基本上会示例化一个巨大的数组,并做其他事情。我在TPU上运行我的代码,所以基本上我的内存是有限的。
如何在CPU上执行特定功能?
如果我这样做:

y = jax.device_put(my_function(), device=jax.devices("cpu")[0])

我猜my_function()首先在TPU上执行,结果放在CPU上,这给了我内存错误。
在代码开头使用jax.config.update('jax_platform_name', 'cpu')似乎没有效果。
另外请注意,我不能修改my_function()
谢谢你!

4xrmg8kj

4xrmg8kj1#

我在这里猜一猜。我也不能运行它,所以你可能要摆弄它

with jax.default_device(jax.devices("cpu")[0]):
    y = my_function()

看这里和这里的文件。

eagi6jfj

eagi6jfj2#

要直接指定执行函数的设备,请使用jax.jitdevice参数。例如(使用GPU运行时,因为它是我目前可以访问的加速器):

import jax

gpu_device = jax.devices('gpu')[0]
cpu_device = jax.devices('cpu')[0]

def my_function(x):
  return x.sum()

x = jax.numpy.arange(10)

x_gpu = jax.jit(my_function, device=gpu_device)(x)
print(x_tpu.device())
# gpu:0

x_cpu = jax.jit(my_function, device=cpu_device)(x)
print(x_cpu.device())
# TFRT_CPU_0

这也可以通过调用点周围的jax.default_device装饰器来控制:

with jax.default_device(cpu_device):
  print(jax.jit(my_function)(x).device())
  # TFRT_CPU_0

with jax.default_device(gpu_device):
  print(jax.jit(my_function)(x).device())
  # gpu:0

相关问题