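I have been working on implementing a custom multi-head attention class for a Transformer model in PyTorch, for learning purposes. My implementation has no extra features; I just want it to work for a basic case. I noticed that with causal attention (tokens cannot attend to future tokens) my model appears to suffer from data leakage. I came to this conclusion after testing the same script with torch's nn.MultiheadAttention class.
It seems to me that the problem lies in the way I apply the mask, but I really can't find the issue. I have tested that a 2D mask is broadcast correctly to a 4D tensor (which is my approach). I have verified several times that the correct tokens are masked, to no avail.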
Here is the code for my attention class:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, dropout=0.1):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.query = nn.Linear(d_model, d_model, bias=False)
        self.key = nn.Linear(d_model, d_model, bias=False)
        self.value = nn.Linear(d_model, d_model, bias=False)
        self.att_proj = nn.Linear(d_model, d_model, bias=False)
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())

    def forward(self, x):
        q = x
        k = x
        v = x
        B,T,C = x.shape
        dk = d_model // n_heads
        # linear projections
        q = self.query(q)
        k = self.key(k)
        v = self.value(v)
        # add number of heads
        q = q.view(B,T,n_heads,dk).permute(0,2,1,3) # B,h,T,dk
        k = k.view(B,T,n_heads,dk).permute(0,2,1,3)
        v = v.view(B,T,n_heads,dk).permute(0,2,1,3)
        # attention
        x = q @ k.transpose(-2,-1) # B,h,T,dk @ B,h,dk,T --> B,h,T,T
        x = x * dk ** -0.5 # B,h,T,T
        x = x.masked_fill(self.mask, float('-inf')) # B,h,T,T
        x = F.softmax(x, dim=(-1)) # B,n_h,T,T
        x = x @ v # B,h,T,T @ B,h,T,dv --> B,h,T,dv
        x = x.view(B,T,-1)
        out = self.att_proj(x) # B,T,C
        return out
```
With a toy example I quickly reach losses such as Training Loss: 2.307, Evaluation Loss: 2.278. When using the torch implementation the losses are noticeably higher: Iteration 9999. Training Loss: 2.469. Evaluation Loss: 2.483. What am I missing?
This is my model implementation, just in case the error is here:
```python
class Model(nn.Module):
    def __init__(self, vocab_size, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embedding_table = nn.Embedding(vocab_size, d_model)
        self.mha = MultiHeadAttention(n_heads, d_model)
        self.out = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x, targets=None):
        x = self.embedding_table(x)
        B, T, C = x.shape
        x = self.mha(x) # B,T,C
        logits = self.out(x) # B,T,vocab_size
        if targets is not None:
            logits = logits.reshape(-1, logits.shape[-1])
            targets = targets.reshape(-1)
            loss = F.cross_entropy(input=logits, target=targets)
        else:
            loss = None
        return logits, loss

    def generate(self, n_chars, ix):
        for _ in range(n_chars):
            logits, loss = self(ix) # B, T, C
            logits = logits[:,-1,:] # B, C -- we need to reshape to calculate probabilities
            probs = F.softmax(logits, dim=-1) # B, C
            next_ix = torch.multinomial(input=probs, num_samples=1)
            ix = torch.cat((ix, next_ix), dim=1)
        return ix
```
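I tried different ways of splitting the training and validation data to make sure no leakage was happening there. I then tried several masking approaches: using tril and filling the 0 entries with -inf, or using triu and filling the True entries with -inf. I have made sure the diagonal is set to 1 (`diagonal=1`) so that only future tokens are masked.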
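For reference, the two masking variants mentioned above can be compared directly. The following is a minimal sketch, not the code from the question: the sizes and variable names (`scores`, `triu_mask`, `tril`) are made up for illustration. It applies both 2D masks to a random 4D score tensor, relying on broadcasting over the batch and head axes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, T = 2, 3, 4  # hypothetical batch size, head count and sequence length
scores = torch.randn(B, h, T, T)  # stand-in for the scaled q @ k^T scores

# Variant 1: True strictly above the diagonal, fill the True entries with -inf
triu_mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
a = scores.masked_fill(triu_mask, float('-inf'))

# Variant 2: ones on and below the diagonal, fill the zero entries with -inf
tril = torch.tril(torch.ones(T, T))
b = scores.masked_fill(tril == 0, float('-inf'))

# The 2D (T, T) masks broadcast over the leading (B, h) axes
weights_a = F.softmax(a, dim=-1)
weights_b = F.softmax(b, dim=-1)

print(torch.equal(weights_a, weights_b))
print(weights_a[0, 0])  # lower-triangular attention weights for one head
```

If both variants give identical weights with zeros above the diagonal, the mask itself is unlikely to be the source of the leak, which points to a later step.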
1 Answer
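I have tentatively found the problem: I was reshaping an intermediate result the wrong way.
After the multiplication with v, I cannot simply do `x = x.view(B,T,-1)`, because at that point x is still laid out as (B, h, T, dv).
Instead, I should swap the head and token axes back before flattening the heads into the last dimension.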
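A minimal sketch of why the direct `view` leaks and how the axes can be restored first. This assumes the common permute-then-flatten approach and is not necessarily the exact line used in this answer; the tensor sizes and the names `wrong` and `right` are hypothetical. Each entry of `x` is filled with the index of the token it belongs to, so any mixing of tokens is visible in the output.

```python
import torch

B, h, T, dv = 1, 2, 4, 3  # hypothetical small sizes for illustration

# Stand-in for softmax(scores) @ v with shape (B, h, T, dv).
# Every entry stores the index of the token (position along T) it belongs to.
x = torch.arange(T, dtype=torch.float32).view(1, 1, T, 1)
x = x.expand(B, h, T, dv).contiguous()

# Direct flatten, as in the question: memory is ordered head-first,
# so each output row mixes values from several tokens.
wrong = x.view(B, T, -1)

# Move the head axis back next to the feature axis, then flatten:
# (B, h, T, dv) -> (B, T, h, dv) -> (B, T, h * dv)
right = x.permute(0, 2, 1, 3).contiguous().view(B, T, -1)

print(wrong[0])  # row 0 already contains values from token 1
print(right[0])  # row t contains only values from token t
```

With the direct `view`, the row for token 0 picks up values computed for token 1, which is how information from future positions can reach earlier ones even though the attention mask is correct.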