是否有一个PyTorch函数等效于numpy.select()?

ergxz8rk  于 2023-03-18  发布在  其他
关注(0)|答案(1)|浏览(121)

使用Numpy的示例:

import numpy as np
x = np.arange(-5, 6)
y = np.select([x<-2, x>2], [x**3, x**2], 5)

我可以使用简单的条件语句,但我需要一些东西来并行处理GPU上的大型数组。

rt4zxlrg

rt4zxlrg1#

这个NumPy表达式看起来等效于递归调用torch.where

y = torch.where(x<-2, x**3, torch.where(x>2, x**2, 5))

这意味着您可以通过以下方式自行实施:

def select(condlist, choicelist, default=0):
    o = default
    for c, v in reversed(list(zip(condlist, choicelist))):
        o = torch.where(c, v, o)
    return o

如果你想用递归的方式来写它,你可以用:

def select(c, v, d=0):
    _c, _v = c.pop(), v.pop()
    r = select(c, v, d) if len(c) else d
    return torch.where(_c, _v, r)

两者都将产生相同的结果:

>>> select([x<-2, x>2], [x**3, x**2], 5)
tensor([-125,  -64,  -27,    5,    5,    5,    5,    5,    9,   16,   25])

相关问题