keras 如何定义硬Sigmoid

6tr1vspr  于 2023-11-19  发布在  其他
关注(0)|答案(5)|浏览(133)

我正在使用keras开发深度网络。有一个激活“硬sigmoid”。它的数学定义是什么?
我知道什么是Sigmoid。有人在Quora上问了类似的问题:https://www.quora.com/What-is-hard-sigmoid-in-artificial-neural-networks-Why-is-it-faster-than-standard-sigmoid-Are-there-any-disadvantages-over-the-standard-sigmoid
但我找不到精确的数学定义?

gopyfrb3

gopyfrb31#

由于Keras同时支持Tensorflow和Theano,因此每个后端的具体实现可能会有所不同-我将仅介绍Theano。对于Theano后端,Keras使用T.nnet.hard_sigmoid,这反过来又是线性近似的标准sigmoid:

slope = tensor.constant(0.2, dtype=out_dtype)
shift = tensor.constant(0.5, dtype=out_dtype)
x = (x * slope) + shift
x = tensor.clip(x, 0, 1)

字符串
例如:max(0, min(1, x*0.2 + 0.5))

8tntrjer

8tntrjer2#

作为参考,hard sigmoid function在不同的地方可能会有不同的定义。在Courbariaux等人。2016 [1]中,它被定义为:
σ是“硬sigmoid”函数:σ(x)= clip((x + 1)/2,0,1)= max(0,min(1,(x + 1)/2))
目的是提供一个概率值(因此将其限制在01之间),用于神经网络参数的随机二值化(例如权重、激活、梯度)。使用从硬sigmoid函数返回的概率p = σ(x)将参数x设置为+1,概率为p,或者-1的概率为1-p
[1][https://arxiv.org/abs/1602.02830](https://arxiv.org/abs/1602.02830)-“Binarized Neural Networks:Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1”,Matthieu Courbariaux,Itay Hubara,丹尼尔Soudry,Ran El-Yaniv,Yoclave Bengio,(提交于2016年2月9日(v1),最后修订于2016年3月17日(此版本,v3))

s5a0g9ez

s5a0g9ez3#

硬sigmoid通常是logistic sigmoid函数的分段线性近似。根据您想要保留原始sigmoid的属性,您可以使用不同的近似。
我个人喜欢将函数保持为零,即σ(0) = 0.5(移位)和σ'(0) = 0.25(斜率)。

def hard_sigmoid(x):
    return np.maximum(0, np.minimum(1, (x + 2) / 4))

字符串

gk7wooem

gk7wooem4#

截至2023年10月,Tensorflow Keras中使用的定义似乎略有变化。tf.keras.activations.hard_sigmoid的文档指出:
硬S形激活,定义为:

if x < -2.5: return 0
if x > 2.5: return 1
if -2.5 <= x <= 2.5: return 0.2 * x + 0.5

字符串
下面是一些绘制函数的代码。

import tensorflow as tf
import numpy
import pandas

x = numpy.linspace(-10, 10, 100) 
a = tf.constant(x, dtype = tf.float32)
h = tf.keras.activations.hard_sigmoid(a).numpy()
s = tf.keras.activations.sigmoid(a).numpy()

df = pandas.DataFrame({
    'input': x,
    'sigmoid': s,
    'hardsigmoid': h,
}).set_index('input')
ax = df.plot()
ax.figure.savefig('hardsigmoid-keras.png')


下面是使用tensorflow 2.14.0运行时的输出。


的数据

uinbv5nw

uinbv5nw5#

clip((x + 1)/2, 0, 1)

字符串
用编码术语来说:

max(0, min(1, (x + 1)/2))

相关问题