将PyTorchTensor中的元素约束为相等

zf9nrax1  于 2023-03-02  发布在  其他
关注(0)|答案(1)|浏览(165)

我有一个PyTorchTensor,希望在优化对其元素施加相等约束。下面显示了一个2 * 9Tensor的示例,其中相同的颜色表示元素应始终相等。

让我们举一个1 * 4的最小例子,并分别初始化前两个和后两个元素为相等。

import torch
x1 = torch.tensor([1.2, 1.2, -0.3, -0.3])
print(x1)
# tensor([ 1.2000,  1.2000, -0.3000, -0.3000])

如果我直接做一个简单的最小二乘法,等式肯定不存在了。

y = torch.arange(4)
opt_1 = torch.optim.SGD([x1], lr=0.1)
opt_1.zero_grad()
loss = (y - x1).pow(2).sum()
loss.backward()
opt_1.step()
print(x1)
# tensor([0.9600, 1.1600, 0.1600, 0.3600], requires_grad=True)

我试着把这个Tensor表示为,掩码的加权和。

def weighted_sum(c, masks):
    return torch.sum(torch.stack([c[0] * masks[0], c[1] * masks[1]]), axis=0)

c = torch.tensor([1.2, -0.3], requires_grad=True)
masks = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1]])
x2 = weighted_sum(c, masks)
print(x2)
# tensor([ 1.2000,  1.2000, -0.3000, -0.3000])

这样,在优化之后保持相等。

opt_c = torch.optim.SGD([c], lr=0.1)
opt_c.zero_grad()
y = torch.arange(4)
x2 = weighted_sum(c, masks)
loss = (y - x2).pow(2).sum()
loss.backward()
opt_c.step()
print(c)
# tensor([0.9200, 0.8200], requires_grad=True)
print(weighted_sum(c, masks))
# tensor([0.9200, 0.9200, 0.8200, 0.8200], grad_fn=<SumBackward1>)

然而,这个解决方案的最大问题是,当输入维度很高时,我必须维护一个大的掩码集合;当然这将导致内存不足。假设输入Tensor的形状是d_0 * d_1 * ... * d_m,等式块的数目是k,则将存在形状为k * d_0 * d_1 * ... * d_m的巨大掩模,这是不可接受的。
另一种解决方案可以是上采样低分辨率Tensor,如this one。然而,它不能应用于不规则等式块,例如,

tensor([[ 1.2000,  1.2000,  1.2000, -3.1000, -3.1000],
        [-0.1000,  2.0000,  2.0000,  2.0000,  2.0000]])

那么......有没有更聪明的方法来实现PyTorchTensor中的等式约束?

63lcw9qa

63lcw9qa1#

如果希望它们始终相等,为什么不删除xy中的第一个和最后一个值呢?额外的值可以在训练后根据需要从模型输出中导出,因为它们无论如何都应该与相邻值相等。不需要学习相同值的两个副本。
如果你想更近似地知道它们是相同的,你可以把some_weight * (torch.abs(x[0]-x[1]) + torch.abs(x[-1] - x[-2]))加到损失函数中,那么你的损失就是试图知道它们是相同的。

相关问题