我们来看看flax中的辍学实现:
def __call__(self, inputs, deterministic: Optional[bool] = None):
"""Applies a random dropout mask to the input.
Args:
inputs: the inputs that should be randomly masked.
deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
masked, whereas if true, no mask is applied and the inputs are returned
as is.
Returns:
The masked inputs reweighted to preserve mean.
"""
deterministic = merge_param(
'deterministic', self.deterministic, deterministic)
if (self.rate == 0.) or deterministic:
return inputs
# Prevent gradient NaNs in 1.0 edge-case.
if self.rate == 1.0:
return jnp.zeros_like(inputs)
keep_prob = 1. - self.rate
rng = self.make_rng(self.rng_collection)
broadcast_shape = list(inputs.shape)
for dim in self.broadcast_dims:
broadcast_shape[dim] = 1
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
mask = jnp.broadcast_to(mask, inputs.shape)
return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
特别是,我对最后一行lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
很感兴趣,想知道为什么这里使用lax.select
而不是:
return jnp.where(mask, inputs / keep_prob, 0)
或者更简单地说:
return mask * inputs / keep_prob
1条答案
按热度按时间vs91vp4v1#
jnp.where
基本上与lax.select
相同,只是其输入更加灵活:例如,它会将输入广播为相同的shape或强制转换为相同的dtype,而lax.select
需要更严格的输入匹配:在库代码中,更严格的语义是有用的,因为它不会消除潜在的实现错误,也不会返回意外的输出,而是会大声抱怨。但在性能方面(尤其是JIT编译之后),两者基本上是等效的。
至于为什么亚麻开发者选择
lax.select
而不是乘以一个遮罩,我可以想到两个原因:1.与掩码相乘受隐式类型提升语义的影响,并且与简单的
select
相比,要预测有问题的输出需要更多的思考,select
是专门为预期操作设计的。1.使用乘法会导致编译器将此操作视为乘法,但实际上并非如此。
select
是一种比乘法更窄更精确的操作,通过精确指定操作,编译器通常可以在更大程度上优化结果。