jax.lax.select vs jax.numpy.在哪里

oaxa6hgo  于 2023-01-02  发布在  其他
关注(0)|答案(1)|浏览(217)

我们来看看flax中的辍学实现:

def __call__(self, inputs, deterministic: Optional[bool] = None):
    """Applies a random dropout mask to the input.

    Args:
      inputs: the inputs that should be randomly masked.
      deterministic: if false the inputs are scaled by `1 / (1 - rate)` and
        masked, whereas if true, no mask is applied and the inputs are returned
        as is.

    Returns:
      The masked inputs reweighted to preserve mean.
    """
    deterministic = merge_param(
        'deterministic', self.deterministic, deterministic)

    if (self.rate == 0.) or deterministic:
      return inputs

    # Prevent gradient NaNs in 1.0 edge-case.
    if self.rate == 1.0:
      return jnp.zeros_like(inputs)

    keep_prob = 1. - self.rate
    rng = self.make_rng(self.rng_collection)
    broadcast_shape = list(inputs.shape)
    for dim in self.broadcast_dims:
      broadcast_shape[dim] = 1
    mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
    mask = jnp.broadcast_to(mask, inputs.shape)
    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))

特别是,我对最后一行lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))很感兴趣,想知道为什么这里使用lax.select而不是:

return jnp.where(mask, inputs / keep_prob, 0)

或者更简单地说:

return mask * inputs / keep_prob
vs91vp4v

vs91vp4v1#

jnp.where基本上与lax.select相同,只是其输入更加灵活:例如,它会将输入广播为相同的shape或强制转换为相同的dtype,而lax.select需要更严格的输入匹配:

>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.arange(3)
# Implicit broadcasting
>>> jnp.where(x < 2, x[:, None], 0)
DeviceArray([[0, 0, 0],
             [1, 1, 0],
             [2, 2, 0]], dtype=int32)

>>> lax.select(x < 2, x[:, None], 0)
TypeError: select cases must have the same shapes, got [(), (3, 1)].
# Implicit type promotion
>>> jnp.where(x < 2, jnp.zeros(3), jnp.arange(3))
DeviceArray([0., 0., 2.], dtype=float32)

>>> lax.select(x < 2, jnp.zeros(3), jnp.arange(3))
TypeError: lax.select requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.where is a similar function that does automatic type promotion on inputs).

在库代码中,更严格的语义是有用的,因为它不会消除潜在的实现错误,也不会返回意外的输出,而是会大声抱怨。但在性能方面(尤其是JIT编译之后),两者基本上是等效的。
至于为什么亚麻开发者选择lax.select而不是乘以一个遮罩,我可以想到两个原因:
1.与掩码相乘受隐式类型提升语义的影响,并且与简单的select相比,要预测有问题的输出需要更多的思考,select是专门为预期操作设计的。
1.使用乘法会导致编译器将此操作视为乘法,但实际上并非如此。select是一种比乘法更窄更精确的操作,通过精确指定操作,编译器通常可以在更大程度上优化结果。

相关问题