How can I efficiently compute per-example gradients over a large dataset in PyTorch?

kse8i1jr · asked 2023-05-07

Given a trained model (M), I want to compute the utility of new (unseen) examples in a pool (for an active learning task). To do this, I need the magnitude of the gradient obtained when M is trained on each new sample individually. In code, it looks something like this:

losses, grads = [], []
for i in range(X_pool.shape[0]):
    pred = model(X_pool[i:i+1])
    loss = loss_func(pred, y_pool[i:i+1])

    model.zero_grad()
    loss.backward()

    losses.append(loss)
    grads.append(layer.weight.grad.norm())

However, this is quite slow when there are many examples, especially since it will sit in an inner loop in my scenario. Is there a more efficient way to do this in PyTorch?

tyu7yeag · #1

Judging by your code, you only need the gradient of a single layer of the model. You can split that layer into multiple copies, each of which processes only one element of the batch. That way, the gradient for that layer is computed per sample, while batching is still used everywhere else.
Here is a complete example comparing your approach (method1) with the one I propose (method2). It should extend easily to more complex networks.

import torch
import torch.nn as nn
import copy

batch_size = 50
num_classes = 10

class SimpleModel(nn.Module):
    def __init__(self, num_classes):
        super(SimpleModel, self).__init__()
        # input 3x10x10
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1, bias=False)
        # 10x10x10
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3, stride=2, padding=1, bias=False)
        # 20x5x5
        self.fc = nn.Linear(20*5*5, num_classes, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.shape[0], -1)
        return self.fc(x)

def method1(model, X_pool, y_pool):
    loss_func = nn.CrossEntropyLoss()
    layer = model.conv2

    losses, grads = [], []
    for i in range(X_pool.shape[0]):
        pred = model(X_pool[i:i+1])
        loss = loss_func(pred, y_pool[i:i+1])

        model.zero_grad()
        loss.backward()

        losses.append(loss)
        grads.append(layer.weight.grad.norm())
    return losses, grads

def method2(model, X_pool, y_pool):
    class Replicated(nn.Module):
        """ Instead of running a batch through one layer, run individuals through copies of layer """
        def __init__(self, layer, batch_size):
            super(Replicated, self).__init__()
            self.batch_size = batch_size
            # ModuleList registers the copies as submodules (so they follow .to()/zero_grad())
            self.layers = nn.ModuleList(copy.deepcopy(layer) for _ in range(batch_size))

        def forward(self, x):
            assert x.shape[0] <= self.batch_size
            return torch.stack([self.layers[idx](x[idx:idx+1, :]) for idx in range(x.shape[0])])

    # compute individual loss functions so we can return them
    loss_func = nn.CrossEntropyLoss(reduction='none')

    # replace layer in model with replicated layer
    layer = model.conv2
    model.conv2 = Replicated(layer, batch_size)
    layers = model.conv2.layers

    # batch of predictions
    pred = model(X_pool)
    losses = loss_func(pred, y_pool)
    # reduce with sum so that the individual loss terms aren't scaled (like with mean) which would also scale the gradients
    loss = torch.sum(losses)
    model.zero_grad()
    loss.backward()
    # per-sample gradient norm, read from each replicated copy of the layer
    grads = [layers[idx].weight.grad.norm() for idx in range(X_pool.shape[0])]

    # convert to list of tensors to match method1 output
    losses = [l for l in losses]

    # put original layer back
    model.conv2 = layer
    return losses, grads

model = SimpleModel(num_classes)
X_pool = torch.rand(batch_size, 3, 10, 10)
y_pool = torch.randint(0, num_classes, (batch_size,))

losses2, grads2 = method2(model, X_pool, y_pool)
losses1, grads1 = method1(model, X_pool, y_pool)

print("Losses Diff:", sum([abs(l1.item()-l2.item()) for l1,l2 in zip(losses1, losses2)]))
print("Grads Diff:", sum([abs(g1.item()-g2.item()) for g1,g2 in zip(grads1, grads2)]))

The numerical difference between the two methods is just floating-point error.

Losses Diff: 3.337860107421875e-06
Grads Diff: 1.9431114196777344e-05

I haven't tested this on larger networks, but I varied batch_size, ran several batches through the network, and saw a 2-3x speedup with this simple model. It should matter even more for a more complex model, because every layer except the replicated one still gets the performance benefit of batching.
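
For reference, a rough way to time the two functions defined above, reusing the model, X_pool, and y_pool from the script (absolute numbers will of course vary with hardware and model size):

import time

t0 = time.perf_counter()
method1(model, X_pool, y_pool)
t1 = time.perf_counter()
method2(model, X_pool, y_pool)
t2 = time.perf_counter()
print(f"method1 (per-example loop):  {t1 - t0:.4f}s")
print(f"method2 (replicated layer):  {t2 - t1:.4f}s")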

Warning: this may not work with DataParallel.

3pvhb19x · #2

You can do this with torch.func.vmap. It is a new utility introduced in PyTorch 2.0 that vectorizes a function by adding a batch dimension. In your case, you can write a pure function that computes the gradient for a single example, then use vmap to vectorize it over a batch of examples:

from torch.func import grad, vmap, functional_call

# A pure function computing the loss of the model
def cost_model(in_, target, params, buffers):
    # Compute the cost for a single instance using pure functions
    in_ = in_.unsqueeze(0)
    target = target.unsqueeze(0)

    out = functional_call(model, (params, buffers), in_)
    cost = loss_fn(out, target)

    return cost
# Use torch.func.grad to obtain a function giving the gradient of the loss
ft_grad = grad(cost_model, argnums=2)

# Use torch.func.vmap to vectorize gradient calculation
ft_all_grads = vmap(ft_grad, in_dims = (0, 0, None, None))

The code I wrote is inspired by the PyTorch tutorial here. Since it is vectorized, it should run considerably faster than iterating over each training example by hand.
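
To actually call ft_all_grads, the model's parameters and buffers are passed in as dictionaries alongside the batch. Below is a minimal usage sketch continuing the snippet above; it assumes model is the SimpleModel from the first answer, loss_fn = nn.CrossEntropyLoss(), X_pool/y_pool as in the question, and picks 'conv2.weight' purely as an illustrative layer of interest:

# assumed setup (not part of the original snippet):
#   model = SimpleModel(num_classes); loss_fn = nn.CrossEntropyLoss()
params = {name: p.detach() for name, p in model.named_parameters()}
buffers = {name: b.detach() for name, b in model.named_buffers()}

# dict mapping parameter name -> tensor of shape (batch_size, *param.shape)
per_sample_grads = ft_all_grads(X_pool, y_pool, params, buffers)

# per-example gradient norm of one layer's weights ('conv2.weight' is just an example)
grad_norms = per_sample_grads['conv2.weight'].flatten(start_dim=1).norm(dim=1)

# per-example losses can be obtained by vmapping the cost function itself
per_sample_losses = vmap(cost_model, in_dims=(0, 0, None, None))(X_pool, y_pool, params, buffers)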
