pytorch 将列表的列表Map到TPUTensor

dgsult0t  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(95)

我有这段代码可以将一个列表分解为不同Tensor。

token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(torch.LongTensor, zip(*batch))

如果我想在gpu上创建这些Tensor,我可以使用下面的代码:

token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(torch.cuda.LongTensor, zip(*batch))

但是现在我想在TPU上创建所有这些,我应该怎么做?有没有像下面这样的东西?

token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(torch.xla.LongTensor, zip(*batch))
wydwbb8l

wydwbb8l1#

您可以按照here指南使用xla设备。
您可以选择设备并将其传递给函数,如下所示:

import torch_xla.core.xla_model as xm
device = xm.xla_device()
token_a_index, token_b_index, isNext, input_ids, segment_ids, masked_tokens, masked_pos = map(lambda x: torch.Tensor(x).to(device).long(), zip(*batch))

您甚至可以参数化设备变量,torch.device("cuda" if torch.cuda.is_available() else "cpu")可用于在cudacpu之间进行选择。

相关问题