我需要写一个函数来嵌入一个RNN数据集,它将输入作为Tensor并将其转换为单词。我知道这个函数具有下面的结构,但不知道如何声明参数。
def rnn_embedder(tensor, embedding_length):
'''
# Takes a tensor and a vocabulary and returns the BoW embedding of that tensor
# Args:
tensor (torch.Tensor): A tensor of words represented by their index in the vocabulary
vocab_lenght (int): The number of entries in the vocabulary
Returns (torch.Tensor): An tensor containing the BoW embedding of the input tensor
'''
tensor = tensor.long()
embedding = ...
words = ...
for ...
return numpy.asarray(embedding)
1条答案
按热度按时间xuo3flqw1#
要声明rnn_embedder函数的参数,您需要指定函数将接受的输入的类型和名称。
第一个参数,Tensor,是一个由词汇表中的索引表示的单词Tensor。这个参数应该是一个 Torch 。Tensor对象。
第二个参数embedding_length是一个整数,表示词汇表中的条目数,这个参数应该是一个int对象。
使用此信息,可以按如下所示声明rnn_embedder函数的参数:
函数定义指定rnn_embedder函数采用两个参数:turch.tensor类型的Tensor和int类型的embedding_length,并返回turch.tensor对象。