import torch
from torch import nn
num_embeddings = 10 # fill base Age
embedding_dim = 256 # fill base of B tensor
embedding = nn.Embedding(num_embeddings, embedding_dim)
A = torch.randint(10, (17809, 6))
print(f"A : {A.shape}")
E_A = embedding(A)
print(f"E_A : {E_A.shape}")
B = torch.rand(17809, 3, 256)
print(f"B : {B.shape}")
C = torch.cat((E_A, B), 1)
print(f"C : {C.shape}")
输出量:
A : torch.Size([17809, 6])
E_A : torch.Size([17809, 6, 256])
B : torch.Size([17809, 3, 256])
C : torch.Size([17809, 9, 256])
1条答案
按热度按时间ep6jt1vc1#
您可以在
A
上应用torch.nn.Embedding
来嵌入数值向量,然后使用torch.cat
将embeding of A
和B
连接到axis=1
上。输出量: