python - Adam optimizer with warmup in PyTorch

8mmmxcuj posted on 2022-12-25 in Python

In Section 5.3 of the paper Attention Is All You Need, the authors suggest increasing the learning rate linearly during a warmup phase and then decreasing it proportionally to the inverse square root of the step number.

How can we implement this with the Adam optimizer in PyTorch, preferably without additional packages?
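For reference, the schedule described in Section 5.3 of the paper (which uses warmup_steps = 4000) can be written as:

lrate = d_model ** -0.5 * min(step_num ** -0.5, step_num * warmup_steps ** -1.5)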


l3zydbqr1#

PyTorch provides learning-rate schedulers that implement various ways of adjusting the learning rate during training. Some simple LR schedulers are already implemented and can be found here: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
In your particular case, you can subclass _LRScheduler, as the other LR schedulers do, to implement a schedule that varies with the number of epochs (or steps). For a basic approach you only need to implement the __init__() and get_lr() methods; see the sketch below.
Note that many schedulers expect you to call .step() once per epoch, but you can also update the learning rate more frequently, or even pass a custom argument as the cosine annealing LR scheduler does: https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#CosineAnnealingLR
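A minimal sketch of that approach (not from the original answer): a Noam-style warmup schedule implemented by subclassing _LRScheduler. The class name NoamLR and the parameters d_model and warmup_steps are illustrative.

import torch
from torch.optim.lr_scheduler import _LRScheduler

class NoamLR(_LRScheduler):
    def __init__(self, optimizer, d_model, warmup_steps, last_epoch=-1):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # last_epoch counts the number of .step() calls made so far
        step = max(1, self.last_epoch)
        scale = self.d_model ** -0.5 * min(step ** -0.5, step * self.warmup_steps ** -1.5)
        # scale every param group's base learning rate by the warmup factor
        return [base_lr * scale for base_lr in self.base_lrs]

# usage sketch: with lr=1.0 the computed scale becomes the effective learning rate
# optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
# scheduler = NoamLR(optimizer, d_model=512, warmup_steps=4000)
# call scheduler.step() after each optimizer.step()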


mftmpeh82#

As suggested in the previous answer, we can use the class introduced by https://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer. However, that code will raise an error when checkpointing unless we also define state_dict() and load_state_dict() methods for the wrapper.
Here is the complete scheduler:

class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.model_size = model_size
        self._rate = 0
    
    def state_dict(self):
        """Returns the state of the warmup scheduler as a :class:`dict`.
        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
    
    def load_state_dict(self, state_dict):
        """Loads the warmup scheduler's state.
        Arguments:
            state_dict (dict): warmup scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict) 
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

Then, to use it in the training loop:

optimizer = NoamOpt(input_opts['d_model'], 500,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
optimizer.step()
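A minimal checkpointing sketch (not part of the original answer; the file name checkpoint.pt is illustrative) showing why the state_dict()/load_state_dict() methods matter. Note that the inner Adam optimizer keeps its own state and is saved separately:

torch.save({
    'scheduler': optimizer.state_dict(),       # NoamOpt state: _step, warmup, model_size, _rate
    'adam': optimizer.optimizer.state_dict(),  # inner Adam optimizer state
}, 'checkpoint.pt')

checkpoint = torch.load('checkpoint.pt')
optimizer.load_state_dict(checkpoint['scheduler'])
optimizer.optimizer.load_state_dict(checkpoint['adam'])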

4dbbbstv3#

import torch

class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))

def get_std_opt(model):
    return NoamOpt(model.src_embed[0].d_model, 2, 4000,
                   torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))

Source: https://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer
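A hypothetical usage sketch (not in the original answer) for a model that does not follow the annotated-Transformer structure: pass the embedding size explicitly instead of reading model.src_embed[0].d_model.

optimizer = NoamOpt(model_size=512, factor=2, warmup=4000,
                    optimizer=torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
# call optimizer.step() once per batch; the wrapper updates the learning rate
# before delegating to the inner Adam optimizer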


oxiaedzo4#

NoamOpt does of course provide a way to implement a warmup learning rate, as shown in https://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer, but it is somewhat old and inconvenient. A smarter way to achieve this is to use the lambda learning rate scheduler that PyTorch supports out of the box.
That is, first define a warmup function that adjusts the learning rate automatically, as follows:

def warmup(current_step: int):
    if current_step < args.warmup_steps:  # current_step / warmup_steps * base_lr
        return float(current_step / args.warmup_steps)
    else:                                 # (num_training_steps - current_step) / (num_training_steps - warmup_steps) * base_lr
        return max(0.0, float(args.training_steps - current_step) / float(max(1, args.training_steps - args.warmup_steps)))

Then build the learning rate scheduler and use it during training:

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup)
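A minimal end-to-end sketch (the model, base learning rate, and the warmup_steps/training_steps values are illustrative assumptions). LambdaLR multiplies the optimizer's base learning rate by the factor returned from warmup, so this gives linear warmup followed by linear decay rather than the paper's inverse-square-root decay; lr_scheduler.step() is called once per training step:

import torch
from torch.optim.lr_scheduler import LambdaLR
from types import SimpleNamespace

# illustrative hyperparameters standing in for the args namespace above
args = SimpleNamespace(warmup_steps=4000, training_steps=100000)
model = torch.nn.Linear(512, 512)  # placeholder model

def warmup(current_step: int):
    if current_step < args.warmup_steps:
        return float(current_step / args.warmup_steps)
    else:
        return max(0.0, float(args.training_steps - current_step) /
                   float(max(1, args.training_steps - args.warmup_steps)))

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
lr_scheduler = LambdaLR(optimizer, lr_lambda=warmup)

for step in range(args.training_steps):
    # ... forward pass and loss.backward() go here ...
    optimizer.step()
    lr_scheduler.step()  # update the learning rate once per training step
    optimizer.zero_grad()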
