使用Numpy的示例:
import numpy as np x = np.arange(-5, 6) y = np.select([x<-2, x>2], [x**3, x**2], 5)
我可以使用简单的条件语句,但我需要一些东西来并行处理GPU上的大型数组。
rt4zxlrg1#
这个NumPy表达式看起来等效于递归调用torch.where:
torch.where
y = torch.where(x<-2, x**3, torch.where(x>2, x**2, 5))
这意味着您可以通过以下方式自行实施:
def select(condlist, choicelist, default=0): o = default for c, v in reversed(list(zip(condlist, choicelist))): o = torch.where(c, v, o) return o
如果你想用递归的方式来写它,你可以用:
def select(c, v, d=0): _c, _v = c.pop(), v.pop() r = select(c, v, d) if len(c) else d return torch.where(_c, _v, r)
两者都将产生相同的结果:
>>> select([x<-2, x>2], [x**3, x**2], 5) tensor([-125, -64, -27, 5, 5, 5, 5, 5, 9, 16, 25])
1条答案
按热度按时间rt4zxlrg1#
这个NumPy表达式看起来等效于递归调用
torch.where
:这意味着您可以通过以下方式自行实施:
如果你想用递归的方式来写它,你可以用:
两者都将产生相同的结果: