After noticing that my custom implementation of first-order MAML might be incorrect, I decided to google how first-order MAML is done officially. I found a useful GitHub issue that suggests turning off the tracking of higher-order gradients. That makes complete sense to me: no more derivatives past first order. But when I set that flag to False (so that higher-order gradients are not tracked), my model no longer trains: the .grad field is None. That is clearly wrong. Is this a bug in higher, or what is going on?

To reproduce, run the official MAML example from higher, slightly modified here. The main code is:
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use higher to do Model Agnostic Meta Learning (MAML)
for few-shot Omniglot classification.
For more details see the original MAML paper:
https://arxiv.org/abs/1703.03400
This code has been modified from Jackie Loong's PyTorch MAML implementation:
https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot_train.py
Our MAML++ fork and experiments are available at:
https://github.com/bamos/HowToTrainYourMAMLPytorch
"""
import argparse
import time
import typing
import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
plt.style.use('bmh')
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import higher
from support.omniglot_loaders import OmniglotNShot

def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--n_way', type=int, help='n way', default=5)
    argparser.add_argument(
        '--k_spt', type=int, help='k shot for support set', default=5)
    argparser.add_argument(
        '--k_qry', type=int, help='k shot for query set', default=15)
    argparser.add_argument(
        '--task_num',
        type=int,
        help='meta batch size, namely task num',
        default=32)
    argparser.add_argument('--seed', type=int, help='random seed', default=1)
    args = argparser.parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Set up the Omniglot loader.
    # device = torch.device('cuda')
    # from uutils.torch_uu import get_device
    # device = get_device()
    gpu_idx = 0  # gpu_idx was undefined in the original snippet; default to GPU 0
    device = torch.device(f"cuda:{gpu_idx}" if torch.cuda.is_available() else "cpu")
    db = OmniglotNShot(
        '/tmp/omniglot-data',
        batchsz=args.task_num,
        n_way=args.n_way,
        k_shot=args.k_spt,
        k_query=args.k_qry,
        imgsz=28,
        device=device,
    )

    # Create a vanilla PyTorch neural network that will be
    # automatically monkey-patched by higher later.
    # Before higher, models could *not* be created like this
    # and the parameters needed to be manually updated and copied
    # for the updates.
    net = nn.Sequential(
        nn.Conv2d(1, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(64, 64, 3),
        nn.BatchNorm2d(64, momentum=1, affine=True),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        Flatten(),
        nn.Linear(64, args.n_way)).to(device)

    # We will use Adam to (meta-)optimize the initial parameters
    # to be adapted.
    meta_opt = optim.Adam(net.parameters(), lr=1e-3)

    log = []
    for epoch in range(100):
        train(db, net, device, meta_opt, epoch, log)
        test(db, net, device, epoch, log)
        # plot(log)

def train(db, net, device, meta_opt, epoch, log):
    net.train()
    n_train_iter = db.x_train.shape[0] // db.batchsz

    for batch_idx in range(n_train_iter):
        start_time = time.time()
        # Sample a batch of support and query images and labels.
        x_spt, y_spt, x_qry, y_qry = db.next()

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?

        # Initialize the inner optimizer to adapt the parameters to
        # the support set.
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        qry_losses = []
        qry_accs = []
        meta_opt.zero_grad()
        for i in range(task_num):
            with higher.innerloop_ctx(
                net, inner_opt, copy_initial_weights=False,
                # track_higher_grads=True,
                track_higher_grads=False,
            ) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                # higher is able to automatically keep copies of
                # your network's parameters as they are being updated.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The final set of adapted parameters will induce some
                # final loss and accuracy on the query dataset.
                # These will be used to update the model's meta-parameters.
                qry_logits = fnet(x_qry[i])
                qry_loss = F.cross_entropy(qry_logits, y_qry[i])
                qry_losses.append(qry_loss.detach())
                qry_acc = (qry_logits.argmax(
                    dim=1) == y_qry[i]).sum().item() / querysz
                qry_accs.append(qry_acc)

                # Update the model's meta-parameters to optimize the query
                # losses across all of the tasks sampled in this batch.
                # This unrolls through the gradient steps.
                qry_loss.backward()

        assert meta_opt.param_groups[0]['params'][0].grad is not None
        meta_opt.step()
        qry_losses = sum(qry_losses) / task_num
        qry_accs = 100. * sum(qry_accs) / task_num
        i = epoch + float(batch_idx) / n_train_iter
        iter_time = time.time() - start_time
        if batch_idx % 4 == 0:
            print(
                f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
            )

        log.append({
            'epoch': i,
            'loss': qry_losses,
            'acc': qry_accs,
            'mode': 'train',
            'time': time.time(),
        })

def test(db, net, device, epoch, log):
    # Crucially in our testing procedure here, we do *not* fine-tune
    # the model during testing for simplicity.
    # Most research papers using MAML for this task do an extra
    # stage of fine-tuning here that should be added if you are
    # adapting this code for research.
    net.train()
    n_test_iter = db.x_test.shape[0] // db.batchsz

    qry_losses = []
    qry_accs = []

    for batch_idx in range(n_test_iter):
        x_spt, y_spt, x_qry, y_qry = db.next('test')

        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)

        # TODO: Maybe pull this out into a separate module so it
        # doesn't have to be duplicated between `train` and `test`?
        n_inner_iter = 5
        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)

        for i in range(task_num):
            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):
                # Optimize the likelihood of the support set by taking
                # gradient steps w.r.t. the model's parameters.
                # This adapts the model's meta-parameters to the task.
                for _ in range(n_inner_iter):
                    spt_logits = fnet(x_spt[i])
                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])
                    diffopt.step(spt_loss)

                # The query loss and acc induced by these parameters.
                qry_logits = fnet(x_qry[i]).detach()
                qry_loss = F.cross_entropy(
                    qry_logits, y_qry[i], reduction='none')
                qry_losses.append(qry_loss.detach())
                qry_accs.append(
                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())

    qry_losses = torch.cat(qry_losses).mean().item()
    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()
    print(
        f'[Epoch {epoch + 1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'
    )
    log.append({
        'epoch': epoch + 1,
        'loss': qry_losses,
        'acc': qry_accs,
        'mode': 'test',
        'time': time.time(),
    })

def plot(log):
    # Generally you should pull your plotting code out of your training
    # script but we are doing it here for brevity.
    df = pd.DataFrame(log)

    fig, ax = plt.subplots(figsize=(6, 4))
    train_df = df[df['mode'] == 'train']
    test_df = df[df['mode'] == 'test']
    ax.plot(train_df['epoch'], train_df['acc'], label='Train')
    ax.plot(test_df['epoch'], test_df['acc'], label='Test')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Accuracy')
    ax.set_ylim(70, 100)
    fig.legend(ncol=2, loc='lower right')
    fig.tight_layout()
    fname = 'maml-accs.png'
    print(f'--- Plotting accuracy to {fname}')
    fig.savefig(fname)
    plt.close(fig)


# Won't need this after this PR is merged in:
# https://github.com/pytorch/pytorch/pull/22245
class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


if __name__ == '__main__':
    main()
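
For quick reference, the failing pattern in the script above distills to the following sketch (same names as in the full example):

# Sketch only: with track_higher_grads=False, higher replaces each adapted
# parameter with p.detach().requires_grad_(), cutting the graph back to
# net's parameters, so backward() never populates their .grad fields.
with higher.innerloop_ctx(
    net, inner_opt, copy_initial_weights=False, track_higher_grads=False,
) as (fnet, diffopt):
    for _ in range(n_inner_iter):
        diffopt.step(F.cross_entropy(fnet(x_spt[i]), y_spt[i]))
    F.cross_entropy(fnet(x_qry[i]), y_qry[i]).backward()

assert meta_opt.param_groups[0]['params'][0].grad is not None  # fails: grad is None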
Notes:

I asked a similar question here: Would making the gradient "data" by detaching them implement first order MAML using PyTorch's higher library?, but it is slightly different. That question asks about a custom implementation that detaches the gradients directly, making them "data". This question asks why setting track_higher_grads=False breaks the population of the .grad fields, which, as far as I can tell, it should not.
Related:
- Bug report, since from the discussion I expected this flag to fix the problem: https://github.com/facebookresearch/higher/issues/129
- https://github.com/facebookresearch/higher/issues?q=is%3Aissue+first+order+maml+is%3Aclosed
- https://github.com/facebookresearch/higher/issues/63
- https://github.com/facebookresearch/higher/issues/128
- https://www.reddit.com/r/pytorch/comments/sixdqd/what_is_the_official_implementation_of_first/
- https://www.reddit.com/r/pytorch/comments/si5xv1/would_making_the_gradient_data_by_detaching_them/
Bounty

Explain why the solution here works, i.e. why

track_higher_grads = True
...
diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads])

computes FO MAML, but:

new_params = params[:]
for group, mapping in zip(self.param_groups, self._group_to_param_list):
    for p, index in zip(group['params'], mapping):
        if self._track_higher_grads:
            new_params[index] = p
        else:
            new_params[index] = p.detach().requires_grad_()  # LIKELY THIS LINE!!!

does not let FO work properly and leaves the .grad fields as None (they are never populated). Honestly, the assignment p.detach().requires_grad_() looks the same to me; if anything, the .requires_grad_() call seems extra "safe".
2 Answers

#1

The reason track_higher_grads=False does not actually work is that it detaches the post-adaptation *parameters* themselves, not just the inner-loop *gradients* (see here). As a result, you get no gradient at all from the outer-loop loss. What you really want is to detach only the gradients computed in the inner loop, while keeping the (otherwise trivial) computation graph between the model initialization and the adapted parameters intact.
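
A minimal sketch (plain autograd, not higher's actual code) of that difference: detaching the adapted *parameter* severs the path back to the initialization, while detaching only the inner *gradient* keeps it:

import torch

w0 = torch.tensor(1.0, requires_grad=True)  # meta-parameter (initialization)
x = torch.tensor(2.0)
lr = 0.1

# What track_higher_grads=False effectively does: the adapted parameter is
# detached from w0, so no gradient can flow back to w0.
g = torch.autograd.grad((w0 * x).pow(2), w0)[0]
w_bad = (w0 - lr * g).detach().requires_grad_()  # like p.detach().requires_grad_()
(w_bad * x).pow(2).backward()
print(w0.grad)  # None: the graph back to w0 was cut

# First-order MAML: detach only the inner gradient; the (trivial) graph
# w_adapted = w0 - lr * g.detach() still connects the outer loss to w0.
g = torch.autograd.grad((w0 * x).pow(2), w0)[0]
w_good = w0 - lr * g.detach()
(w_good * x).pow(2).backward()
print(w0.grad)  # populated: first-order gradient w.r.t. the initialization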
#2

I think I found the solution, though it is hard to be 100% sure it is correct because I do not fully understand it. But I ran multiple sanity checks, and it does change both higher's behavior and the speed of the code, so I assume it really does make FO work:
1. Set track_higher_grads = True.
2. But call the differentiable optimizer with the following grads callback (the two steps are combined in the sketch below):

diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads])
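
Putting the two steps together, the inner loop of the example above would look roughly like this (a sketch; net, inner_opt, n_inner_iter, i, and the data tensors are as in the question's code):

with higher.innerloop_ctx(
    net, inner_opt, copy_initial_weights=False,
    track_higher_grads=True,  # keep the graph from the initialization to the adapted params
) as (fnet, diffopt):
    for _ in range(n_inner_iter):
        spt_loss = F.cross_entropy(fnet(x_spt[i]), y_spt[i])
        # Detach only the inner-loop *gradients*: no second-order terms are
        # kept, but the update w' = w - lr * g.detach() still links w' to w.
        diffopt.step(spt_loss, grad_callback=lambda grads: [g.detach() for g in grads])
    qry_loss = F.cross_entropy(fnet(x_qry[i]), y_qry[i])
    qry_loss.backward()  # populates .grad on net.parameters() with FO-MAML gradients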
Sanity checks:

First, run with track_higher_order_grads = True but without Eric's grads_callback trick. The code is deterministic, so if I run it again it should print the same numbers, and it does.

Now 🙂 let's change the seed (from 0 to 42, 142, 1142); the grad-norm values should change, and they do.

Now back to seed 0: close enough again! 🙂
Now, if Eric's trick works (passing a grads callback), the gradient values should change, since it is now doing FO with no higher-order information. So I changed my code step by step. First, I kept track_higher_order_grads = True and used the callback; that produces a certain gradient value. Running it again gives the same value (confirming the code is deterministic), and the value differs from the earlier runs, confirming that this combination does something different (i.e., the grads_callback changes the behavior).

Now, what if I use Eric's callback but with track_higher_order_grads=False? That gives the bug. So setting track_higher_order_grads to False seems to always be wrong.

This makes me think your solution at least changes the behavior, though I do not know why it works or why higher's original code does not.
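
For reference, the grad-norm fingerprint used in these sanity checks can be computed with a small helper like the following (hypothetical, not from the original post); call it right after qry_loss.backward() and before meta_opt.step():

def print_total_grad_norm(net):
    # L2 norm over all populated .grad fields: a cheap fingerprint for
    # comparing runs (same seed + deterministic code => same value).
    total = 0.0
    for p in net.parameters():
        if p.grad is not None:
            total += p.grad.norm().item() ** 2
    print(f'total grad norm: {total ** 0.5:.6f}')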
Next I will check how fast the code runs by reading the tqdm output. If it is really doing FO (and not using higher-order gradients), there should be some speedup. I am running this code on my M1 laptop. The first combination is track_higher_grads = True with diffopt.step(inner_loss, grad_callback=lambda grads: [g.detach() for g in grads]), so this should be FO (the fast one) and should finish sooner than the next run, which keeps the higher-order grads (the Hessian).

Then I run track_higher_grads = True with plain diffopt.step(inner_loss), which does keep the higher-order gradients (the Hessian).

Since that run took longer, I conclude it really did use the Hessian and was not FO. I think the difference would be even more pronounced with a larger network (due to the roughly quadratically sized Hessian).
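
A rough way to reproduce this speed comparison (a hypothetical harness, not from the original post) is to time one call of train() under each configuration, after toggling the diffopt.step(...) line inside it:

import time

# Assumes db, net, device, meta_opt come from main() above, and that train()
# has been edited to use the configuration being timed.
t0 = time.time()
train(db, net, device, meta_opt, epoch=0, log=[])
print(f'one epoch took {time.time() - t0:.1f}s')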
Reproducible code:

First-order MAML in my real script:

Now the non-FO (higher-order) MAML run:

The tqdm estimate is 6 days for FO versus 13 days for higher order, so it is most likely correct!