PyTorch: custom Multihead Attention class leaks data under causal attention

laik7k3q · posted 2023-08-05 in Other

I have been working on implementing a custom multi-head attention class for a Transformer model in PyTorch, for learning purposes. My implementation deliberately lacks any extra features; I just want it to work for a basic case. I noticed that with causal attention (tokens must not attend to future tokens) my model seems to suffer from data leakage. I reached this conclusion after testing the same script with torch's nn.MultiheadAttention class.
It looks to me like the problem is in the way I apply the mask, but I really cannot find it. I have tested that a 2D mask broadcasts correctly onto a 4D tensor (which is my approach), and I have verified several times that the correct tokens are being masked, to no avail.
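A sketch of such a broadcast check (shapes and variable names here are illustrative, not taken from the original script):

```
import torch

B, h, T = 2, 3, 4
scores = torch.zeros(B, h, T, T)                        # B,h,T,T attention scores
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()  # True above the diagonal
masked = scores.masked_fill(mask, float('-inf'))        # 2D mask broadcasts over B and h
print(masked[0, 0])                                     # row i has -inf only in columns j > i
```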
Here is the code:

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):

    def __init__(self, n_heads, d_model, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.query = nn.Linear(d_model, d_model, bias=False)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.att_proj = nn.Linear(d_model, d_model, bias=False)
        # block_size is a module-level global; True above the diagonal marks future positions
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())

    def forward(self, x):
        q = x
        k = x
        v = x
        B, T, C = x.shape
        dk = d_model // n_heads  # d_model and n_heads are module-level globals

        # linear projections
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)

        # split into heads
        q = q.view(B, T, n_heads, dk).permute(0, 2, 1, 3)  # B,h,T,dk
        k = k.view(B, T, n_heads, dk).permute(0, 2, 1, 3)
        v = v.view(B, T, n_heads, dk).permute(0, 2, 1, 3)

        # attention
        x = q @ k.transpose(-2, -1)  # B,h,T,dk @ B,h,dk,T --> B,h,T,T
        x = x * dk ** -0.5  # B,h,T,T
        x = x.masked_fill(self.mask, float('-inf'))  # B,h,T,T (assumes T == block_size)
        x = F.softmax(x, dim=-1)  # B,h,T,T
        x = x @ v  # B,h,T,T @ B,h,T,dk --> B,h,T,dk
        x = x.view(B, T, -1)  # merge heads back
        out = self.att_proj(x)  # B,T,C

        return out
```

With a toy example my implementation quickly gets down to losses such as Training Loss: 2.307, Evaluation Loss: 2.278, whereas with torch's nn.MultiheadAttention the losses are far less ambitious: Iteration 9999, Training Loss: 2.469, Evaluation Loss: 2.483. What am I missing?
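For reference, the built-in baseline can be wired up roughly like this; a sketch, assuming batch_first tensors and the same d_model / n_heads globals as above:

```
import torch
import torch.nn as nn

# causal baseline using the built-in layer (assumes d_model and n_heads are defined)
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, bias=False, batch_first=True)

def causal_self_attention(x):  # x: B,T,C
    T = x.size(1)
    # boolean mask: True positions are not allowed to be attended to
    attn_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
    out, _ = mha(x, x, x, attn_mask=attn_mask, need_weights=False)
    return out  # B,T,C
```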

Here is my model implementation, just in case the error is there:

```
class Model(nn.Module):

    def __init__(self, vocab_size, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.embedding_table = nn.Embedding(vocab_size, d_model)
        self.mha = MultiHeadAttention(n_heads, d_model)
        self.out = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x, targets=None):
        x = self.embedding_table(x)
        B, T, C = x.shape

        x = self.mha(x)  # B,T,C
        logits = self.out(x)  # B,T,vocab_size

        if targets is not None:
            logits = logits.reshape(-1, logits.shape[-1])
            targets = targets.reshape(-1)
            loss = F.cross_entropy(input=logits, target=targets)
        else:
            loss = None

        return logits, loss

    def generate(self, n_chars, ix):
        for _ in range(n_chars):
            logits, loss = self(ix)  # B,T,vocab_size
            logits = logits[:, -1, :]  # B,vocab_size -- keep only the last position
            probs = F.softmax(logits, dim=-1)  # B,vocab_size
            next_ix = torch.multinomial(input=probs, num_samples=1)
            ix = torch.cat((ix, next_ix), dim=1)

        return ix
```
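The two classes above also rely on globals that the question does not show (block_size, d_model, n_heads) plus a vocab_size. Purely for completeness, an illustrative setup that would let the snippets run (the actual values used in the question are unknown):

```
import torch

# illustrative values only; the question does not state the real hyperparameters
block_size = 8     # context length; the causal mask buffer is block_size x block_size
d_model = 32       # embedding width, must be divisible by n_heads
n_heads = 4
vocab_size = 65    # e.g. a small character-level vocabulary

model = Model(vocab_size)
x = torch.randint(0, vocab_size, (2, block_size))  # B,T batch of token ids
logits, loss = model(x)                            # logits: (2, block_size, vocab_size); loss is None
```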


I tried different training/validation split approaches to make sure no leakage was happening there. I then tried several masking approaches: using tril and filling the 0s with -inf, or using triu and filling the True entries with -inf. I made sure the diagonal offset is 1, so that only future tokens are masked.
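Both masking variants described there do produce the same pattern; a quick equivalence check (illustrative names and sizes):

```
import torch

T = 4
scores = torch.randn(T, T)

# variant 1: tril of ones, fill the 0 positions (above the diagonal) with -inf
masked_a = scores.masked_fill(torch.tril(torch.ones(T, T)) == 0, float('-inf'))

# variant 2: boolean triu with diagonal=1, fill the True positions with -inf
masked_b = scores.masked_fill(torch.triu(torch.ones(T, T), diagonal=1).bool(), float('-inf'))

assert torch.equal(masked_a, masked_b)  # both hide exactly the future positions
```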

nkoocmlb 1#

I have tentatively found the problem: I was reshaping an intermediate result the wrong way.
Right after the multiplication with v I cannot do

```
x = x.view(B,T,-1)
```

Instead, what I should do is:

```
B,h,T,dv = x.shape
x = x.transpose(2,1).contiguous().view(B,T,h*dv)  # B,T,C
```
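The reason the plain view leaks is that on a (B, h, T, dv) tensor it slices the flat memory in head-major order, so the row for position t ends up containing head outputs computed at other (including future) time steps. A small sketch of the difference, with illustrative shapes:

```
import torch

B, h, T, dv = 1, 2, 3, 4
x = torch.arange(B * h * T * dv, dtype=torch.float32).view(B, h, T, dv)

wrong = x.view(B, T, h * dv)                               # splices memory in head-major order
right = x.transpose(1, 2).contiguous().view(B, T, h * dv)  # groups all heads of the same time step

print(torch.equal(wrong, right))  # False: the plain view mixes the head and time axes
```

In this example, row 0 of `wrong` already contains the values of head 0 at time step 1, which is exactly the kind of future information that shows up as a suspiciously low training loss.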
