PyTorch: what is the result/effect of calling a layer's __init__ inside a class without assigning it to a variable?

Asked by 5q4ezhmt on 2023-11-19 in: Other

First of all, I should say that I am not very good at Python yet.
Below is a snippet from the original Microsoft LoRA implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

import math
from typing import Optional, List

class LoRALayer():
    def __init__(
        self, 
        r: int, 
        lora_alpha: int, 
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights

class Embedding(nn.Embedding, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                            merge_weights=merge_weights)
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True
        
    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            after_A = F.embedding(
                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
                self.norm_type, self.scale_grad_by_freq, self.sparse
            )
            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return nn.Embedding.forward(self, x)

I have two Python questions about this code.
1. In the Embedding class, what is the role/effect of calling nn.Embedding.__init__ and LoRALayer.__init__ without assigning them to variables, as shown below?

nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                   merge_weights=merge_weights)


They are not written like embedding = nn.Embedding(...) or lora = LoRALayer(...). What do they do inside the Embedding class when they are not assigned to variables?
2. Apart from the call LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0, merge_weights=merge_weights), the LoRALayer class is never used or called anywhere else inside the Embedding class.
So how does it work, and why is it there at all?
Thanks in advance for any explanation.


iyr7buue #1

Regarding the first question: this code uses Python's syntax for (multiple) inheritance. class Embedding has two parent classes, nn.Embedding and LoRALayer. The calls nn.Embedding.__init__(self, ...) and LoRALayer.__init__(self, ...) run each parent's __init__ on the object that is currently being constructed (self). That is why there are no variables like embedding = nn.Embedding() or lora = LoRALayer(): nothing new is being created, the existing instance is simply being initialized by both parents.
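Here is a minimal sketch with made-up toy classes (nothing to do with LoRA itself) showing that Parent.__init__(self, ...) does not create a new object; it only runs the parent's initialization code on the instance that already exists:

class Engine:
    def __init__(self, horsepower):
        self.horsepower = horsepower

class Radio:
    def __init__(self, station):
        self.station = station

class Car(Engine, Radio):
    def __init__(self, horsepower, station):
        # Run both parents' __init__ on this same, already-existing instance.
        Engine.__init__(self, horsepower)
        Radio.__init__(self, station)

car = Car(150, "FM 101.5")
print(car.horsepower, car.station)  # 150 FM 101.5
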
Regarding the second part: class LoRALayer only has an __init__ method. It is called once, when an Embedding instance is created, and all it does is set initial values for fields on self (self.r, self.lora_alpha, self.lora_dropout, self.merged, self.merge_weights). Those fields are then used inside class Embedding, for example self.r and self.merged in train() and forward(). In effect, LoRALayer is a small mixin that attaches the LoRA bookkeeping fields to the layer.
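To see this concretely, here is a quick check (assuming the LoRALayer and Embedding classes from the question, together with their imports, have already been defined in the session):

# Assumes the question's LoRALayer and Embedding classes are defined.
emb = Embedding(num_embeddings=100, embedding_dim=16, r=4)

# The instance is simultaneously an nn.Embedding and a LoRALayer.
print(isinstance(emb, nn.Embedding), isinstance(emb, LoRALayer))  # True True

# Fields set inside LoRALayer.__init__ live directly on the instance
# and are read later by Embedding.train() / Embedding.forward().
print(emb.r, emb.lora_alpha, emb.merged, emb.merge_weights)  # 4 1 False True

# Fields set inside Embedding.__init__ after the two parent calls.
print(emb.weight.requires_grad)            # False (pre-trained weight is frozen)
print(emb.lora_A.shape, emb.lora_B.shape)  # torch.Size([4, 100]) torch.Size([16, 4])
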
In my opinion this is not the nicest way to achieve the goal, but that is how it works.
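For comparison, here is a sketch of the more conventional cooperative pattern using super() with the same toy classes (purely illustrative; the LoRA authors call each parent explicitly, presumably because nn.Embedding and LoRALayer were not written to cooperate via super()):

class Engine:
    def __init__(self, horsepower, **kwargs):
        super().__init__(**kwargs)   # forward remaining kwargs up the MRO
        self.horsepower = horsepower

class Radio:
    def __init__(self, station, **kwargs):
        super().__init__(**kwargs)
        self.station = station

class Car(Engine, Radio):
    def __init__(self, horsepower, station):
        # A single super() call walks the whole MRO: Car -> Engine -> Radio.
        super().__init__(horsepower=horsepower, station=station)

car = Car(horsepower=150, station="FM 101.5")
print(car.horsepower, car.station)  # 150 FM 101.5
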
