pytorch 如何获得范围(0,0.5)内的S形函数的大多数输出?

dy1byipe  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(146)

我在最后一层使用了S形函数。我们知道S形函数将网络的输出限制在范围(0,1)内。我希望大多数输出都在范围(0,0.5)内,只有很少的输出在范围[0.5,1)内。我如何在Pytorch中做到这一点以获得所需的输出呢?
下面的Pytorch代码片段与此问题相关:

class Generator(nn.Module):
def __init__(self):
    super(Generator, self).__init__()
    #
    def block(in_feat, out_feat, normalize=True):
        layers = [nn.Linear(in_features=in_feat, out_features=out_feat)]
        if normalize:
            layers.append(nn.BatchNorm1d(out_feat))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers
    # now we can use this function like below:
    self.model = nn.Sequential(*block(params.input_dim_generator, 500, normalize=False),
                               *block(500, 350),
                               *block(350, 256),
                               nn.Linear(256, 564),
                               nn.Sigmoid())

# forward

    def forward(self, old_vector, z):
        vector_app = torch.cat((old_vector, z), dim=1)
        new_vector = self.model(vector_app)
        new_result = torch.max(new_vector, old_vector).float()
        return new_result

z是范围(0,1)中的随机噪声向量,而old_vector是二进制向量(值为0或1)。
该模型输出是:

torch.tensor([0.5167, 0.5281, 0.5804, 0.4372, 1.0000, 1.0000, 1.0000, 0.5501, 1.0000,
        0.6154, 1.0000, 1.0000, 0.4699, 0.5536, 0.5005, 0.4318, 0.5302, 0.4830,
        0.5404, 0.3597, 0.4639, 0.5885, 0.4997, 0.5881, 0.5046, 0.5670, 0.3977,
        0.5186, 0.5859, 0.5398, 0.3954, 0.4839, 0.3310, 0.5208, 0.5420, 0.5056,
        0.5022, 0.6316, 0.6185, 0.5142, 0.5536, 0.4988, 0.5250, 0.4813, 0.5150,
        0.4080, 1.0000, 1.0000, 1.0000, 0.6054, 0.4766, 0.4423, 0.4520, 0.4816,
        0.5159, 0.4582, 1.0000, 0.4550, 0.4956, 1.0000, 0.5934, 1.0000, 0.4809,
        0.5512, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.4024, 0.4822, 1.0000,
        0.5310, 1.0000, 0.5127, 1.0000, 0.5441, 0.5063, 1.0000, 0.5511, 0.5544,
        1.0000, 0.4585, 0.5211, 0.5758, 0.4355, 1.0000, 0.5297, 0.4582, 0.4170,
        1.0000, 1.0000, 0.5257, 0.4194, 0.3583, 0.5087, 0.5936, 0.4851, 0.5697,
        0.4261, 0.4736, 0.4551, 1.0000, 0.5667, 0.5650, 1.0000, 0.5069, 0.5901,
        0.4980, 0.5184, 1.0000, 1.0000, 0.5435, 1.0000, 1.0000, 1.0000, 1.0000,
        0.4521, 1.0000, 0.4509, 1.0000, 0.5067, 1.0000, 0.4152, 0.5034, 0.5735,
        0.4040, 1.0000, 0.4492, 1.0000, 0.4405, 1.0000, 1.0000, 0.5667, 0.5639,
        0.4013, 0.4357, 0.4437, 0.4510, 0.4225, 0.5091, 0.5057, 1.0000, 0.5237,
        0.5098, 1.0000, 0.4216, 0.5242, 0.5335, 0.3916, 0.4938, 1.0000, 0.4070,
        0.5210, 1.0000, 1.0000, 0.4050, 0.3960, 0.5750, 0.4906, 0.4991, 1.0000,
        0.3149, 0.2949, 1.0000, 0.4515, 0.3627, 0.4348, 0.3887, 0.5807, 0.5787,
        0.5781, 1.0000, 1.0000, 1.0000, 1.0000, 0.4919, 1.0000, 1.0000, 0.5554,
        0.5515, 1.0000, 0.5472, 0.3342, 0.5705, 0.5076, 0.6348, 0.4436, 0.4683,
        0.4228, 0.6506, 0.4540, 0.5333, 0.4512, 0.6037, 0.5173, 1.0000, 0.4466,
        0.5644, 0.5565, 0.5141, 0.4771, 0.5822, 0.4888, 1.0000, 0.6331, 0.6435,
        1.0000, 0.5012, 1.0000, 0.4864, 1.0000, 0.4994, 0.4326, 0.4347, 0.3606,
        0.5829, 0.5229, 1.0000, 0.5992, 0.5883, 0.4825, 0.6254, 0.4951, 0.4285,
        0.4982, 1.0000, 0.5847, 0.4131, 0.5194, 0.5270, 0.4856, 0.6182, 0.5578,
        1.0000, 0.5460, 0.5023, 0.6279, 0.5727, 0.5997, 0.4903, 0.5633, 0.5070,
        0.5013, 1.0000, 0.4179, 0.5529, 0.6254, 0.5767, 0.3939, 0.5791, 0.4936,
        0.4714, 0.5150, 0.5717, 0.4570, 0.4463, 0.5493, 0.5179, 1.0000, 0.5682,
        0.5451, 0.5266, 0.5571, 1.0000, 1.0000, 0.5506, 0.4710, 0.5951, 1.0000,
        0.5027, 1.0000, 1.0000, 0.4960, 0.6269, 0.4817, 1.0000, 0.4059, 0.4787,
        0.4419, 0.5479, 0.4830, 0.4709, 0.6106, 0.6154, 0.3958, 0.6434, 0.4626,
        0.5954, 0.5083, 0.5121, 1.0000, 0.5139, 1.0000, 0.5428, 1.0000, 0.5278,
        0.5255, 0.5854, 0.4400, 0.4774, 0.4431, 0.4871, 0.3854, 0.6217, 0.5562,
        0.4461, 0.5191, 0.5654, 0.4428, 0.5503, 0.5742, 1.0000, 0.4899, 1.0000,
        0.5229, 0.5428, 0.4285, 0.3038, 0.3029, 0.5145, 0.6747, 0.5685, 0.5268,
        0.4888, 0.6431, 0.5308, 0.6249, 0.4531, 0.5631, 0.4498, 0.4465, 0.5125,
        0.5610, 1.0000, 0.5033, 0.5517, 1.0000, 0.4625, 0.5095, 1.0000, 0.3415,
        0.4749, 1.0000, 0.4567, 1.0000, 0.4417, 0.5623, 1.0000, 0.4780, 0.4218,
        1.0000, 0.5474, 0.6514, 0.5725, 0.4219, 0.5303, 0.3375, 0.5710, 0.5507,
        0.3698, 0.4902, 0.6082, 0.5212, 0.5606, 0.5320, 0.4893, 0.3831, 0.4605,
        0.5409, 0.4605, 0.5774, 0.5709, 0.5020, 0.5771, 0.4032, 0.5832, 0.4454,
        0.4572, 0.4651, 0.4752, 0.5786, 0.4700, 0.3398, 0.4143, 0.4413, 0.4020,
        0.6390, 0.5165, 0.4871, 0.6229, 0.4915, 1.0000, 0.4780, 0.5900, 0.4847,
        0.4583, 0.5889, 0.4291, 0.4095, 0.5258, 1.0000, 0.4875, 1.0000, 0.5174,
        0.4302, 1.0000, 0.5058, 0.5917, 0.5395, 0.3915, 0.4775, 0.4688, 0.4860,
        0.4869, 0.4189, 1.0000, 0.6453, 0.4652, 0.5106, 0.4336, 0.4959, 0.5144,
        1.0000, 1.0000, 0.4382, 0.5917, 1.0000, 0.5123, 0.4299, 0.5447, 1.0000,
        0.5316, 0.4145, 0.5741, 1.0000, 0.4581, 0.5953, 1.0000, 0.4909, 0.3703,
        0.3851, 0.5324, 1.0000, 0.6660, 1.0000, 0.5687, 0.4825, 0.5081, 0.5052,
        0.6288, 0.5371, 0.4286, 1.0000, 0.6535, 0.5556, 0.5390, 0.3320, 1.0000,
        0.6431, 0.5405, 1.0000, 0.3641, 0.4390, 0.6196, 0.4720, 0.5114, 0.4844,
        0.4184, 0.6269, 1.0000, 0.4077, 0.3950, 0.4502, 1.0000, 0.4417, 0.4329,
        0.5803, 0.4967, 0.5248, 0.5182, 0.4417, 0.4066, 0.6219, 0.3435, 1.0000,
        0.4680, 1.0000, 0.5403, 0.4570, 1.0000, 0.5805, 1.0000, 0.5796, 0.5100,
        0.6487, 0.4752, 0.4579, 0.6026, 0.5964, 0.5842, 0.3423, 0.5475, 0.4467,
        0.4494, 0.4782, 0.6054, 0.4499, 0.4691, 0.4700, 0.5006, 0.5895, 0.3947,
        0.5517, 0.4240, 0.5286, 0.4796, 0.5116, 0.5696, 0.4369, 0.4761, 0.5444,
        0.4490, 0.6399, 0.5469, 0.5155, 0.5339, 0.5860, 0.6092, 0.4000, 0.4622,
        0.4235, 0.5554, 0.4088, 0.5798, 0.5034, 0.4752, 0.4337, 0.4786, 0.5766,
        0.4569, 0.5401, 0.4903, 0.4243, 0.3825, 0.6652, 0.4780, 0.5335, 0.4415,
        0.5478, 0.3797, 1.0000, 0.6133, 0.5824, 0.4292, 0.5182, 0.3953, 0.5071,
        0.5131, 0.4735, 1.0000, 0.3457, 0.5933, 0.5329])
kmpatx3s

kmpatx3s1#

正如我所提到的,减少最后一层的偏差是减少平均输出的一种方法。下面的脚本尝试了几个不同的偏差值,并绘制了对输出分布的影响。

import torch
from torch import nn
import matplotlib.pyplot as plt

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        #
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_features=in_feat, out_features=out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        # now we can use this function like below:
        self.model = nn.Sequential(*block(input_dim_generator, 500, normalize=False),
                                   *block(500, 350),
                                   *block(350, 256),
                                   nn.Linear(256, 564),
                                   nn.Sigmoid())

    # forward
    def forward(self, old_vector, z):
        vector_app = torch.cat((old_vector, z), dim=1)
        new_vector = self.model(vector_app)
        new_result = torch.max(new_vector, old_vector).float()
        return new_result

z = torch.rand(160 , 70)
input_dim_generator = 634
old_vector = torch.randint(2,(160, 564))

gen = Generator()

bias_shift = 0
bias_delta = -.5

fig, ax = plt.subplots(2,2, figsize = (8,8))

with torch.no_grad():
    for i in range(4):
        ax_cur = ax[i//2][i%2]
        ax_cur.hist(gen(old_vector, z).numpy().ravel(),bins=20)
        ax_cur.set_title(f"{bias_shift} added to bias")
        bias_shift += bias_delta
        gen.model[-2].bias += bias_delta

生成的直方图:

我们可以看到,减少偏差会影响输出的分布,但似乎不会影响正好等于1的输出的数量(我怀疑它直接来自旧向量)。

相关问题