I noticed that when looping over my `DataLoader()` object with `enumerate()`, I get a forced new dimension that is used to build the batches of my data.
I have 4 tensors that I slice at the individual level (I have panel data, so I slice the data into per-individual blocks rather than by rows/observations): `X` (3-D), `Y` (2-D), `Z` (2-D), and `id` (2-D).
The data contains 10 observations, but the sample has only 5 individuals (so each individual contributes 2 observations); hence every per-individual block in my data holds two observations. Since I set `batch_size = 2` (individuals per batch), I get 4 observations for the first and second batches and only 2 observations for the third.
This behavior is shown in the output below:

```
Selection of the data for by __getitem__ for individual 1
torch.Size([2, 3, 3]) X_batch when selecting for ind 1
torch.Size([2, 3]) Z_batch when selecting for ind 1
torch.Size([2, 1]) Y_batch when selecting for ind 1
Selection of the data for by __getitem__ for individual 2
torch.Size([2, 3, 3]) X_batch when selecting for ind 2
torch.Size([2, 3]) Z_batch when selecting for ind 2
torch.Size([2, 1]) Y_batch when selecting for ind 2
Data of the Batch # 1 inside the enumerate
shape X (outside foo) torch.Size([2, 2, 3, 3]) # <<-- here I have a new dimension
shape Z (outside foo) torch.Size([2, 2, 3])
shape Y (outside foo) torch.Size([2, 2, 1])
Selection of the data for by __getitem__ for individual 3
torch.Size([2, 3, 3]) X_batch when selecting for ind 3
torch.Size([2, 3]) Z_batch when selecting for ind 3
torch.Size([2, 1]) Y_batch when selecting for ind 3
Selection of the data for by __getitem__ for individual 4
torch.Size([2, 3, 3]) X_batch when selecting for ind 4
torch.Size([2, 3]) Z_batch when selecting for ind 4
torch.Size([2, 1]) Y_batch when selecting for ind 4
Data of the Batch # 2 inside the enumerate
shape X (outside foo) torch.Size([2, 2, 3, 3]) # <<-- here I have a new dimension
shape Z (outside foo) torch.Size([2, 2, 3])
shape Y (outside foo) torch.Size([2, 2, 1])
Selection of the data for by __getitem__ for individual 5
torch.Size([2, 3, 3]) X_batch when selecting for ind 5
torch.Size([2, 3]) Z_batch when selecting for ind 5
torch.Size([2, 1]) Y_batch when selecting for ind 5
Data of the Batch # 3 inside the enumerate
shape X (outside foo) torch.Size([1, 2, 3, 3]) # <<-- here I have a new dimension
shape Z (outside foo) torch.Size([1, 2, 3])
shape Y (outside foo) torch.Size([1, 2, 1])
```

First the data corresponding to the first and second individuals is selected, but inside the `enumerate()` loop I get a new dimension (dim `0`) that PyTorch uses to stack the per-individual blocks.
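To illustrate where the extra dimension comes from, here is a minimal sketch contrasting stacking (what I get) with the concatenation I am after, using two hypothetical tensors `a` and `b` standing in for the blocks of individuals 1 and 2:

```python
import torch

# Hypothetical per-individual blocks: 2 observations x 3 regressors x 3
# alternatives, the same shape as the X_batch tensors printed above.
a = torch.randn(2, 3, 3)
b = torch.randn(2, 3, 3)

# The DataLoader's default collation stacks along a NEW leading dim:
print(torch.stack([a, b], dim=0).shape)  # torch.Size([2, 2, 3, 3])

# What I would like instead: concatenation along the observation dim.
print(torch.cat([a, b], dim=0).shape)    # torch.Size([4, 3, 3])
```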
So my question is:

- In order to store the whole batch, is there any way to `torch.cat(..., dim=0)` the blocks of data together instead of creating this new dimension?

For instance, for the first batch I need the following:

```
Data of the Batch # 1 inside the enumerate
shape X (outside foo) torch.Size([4, 3, 3]) # <<-- here I want torch.cat(..., dim=0)
shape Z (outside foo) torch.Size([4, 3])
shape Y (outside foo) torch.Size([4, 1])
```

The code that generates the output above is listed below. Thanks!
Sample data

```python
import torch
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import argparse
# args to be passed to the model
parser = argparse.ArgumentParser(description='Neural network for Flexible utility (VOT =f(z))')
args = parser.parse_args("")
args.J = 3 # number of alternatives
# Sample data
X = pd.DataFrame.from_dict({'x1_1': {0: -0.1766214634108258, 1: 1.645852185286492, 2: -0.13348860101031038, 3: 1.9681043689968933, 4: -1.7004428240831382, 5: 1.4580091413853749, 6: 0.06504113741068565, 7: -1.2168493676768384, 8: -0.3071304478616376, 9: 0.07121332925591593}, 'x1_2': {0: -2.4207773498298844, 1: -1.0828751040719462, 2: 2.73533787008624, 3: 1.5979611987152071, 4: 0.08835542172064115, 5: 1.2209786277076156, 6: -0.44205979195950784, 7: -0.692872860268244, 8: 0.0375521181289943, 9: 0.4656030062266639}, 'x1_3': {0: -1.548320898226322, 1: 0.8457342014424675, 2: -0.21250514722879738, 3: 0.5292389938329516, 4: -2.593946520223666, 5: -0.6188958526077123, 6: 1.6949245117526974, 7: -1.0271341091035742, 8: 0.637561891142571, 9: -0.7717170035055559}, 'x2_1': {0: 0.3797245517345564, 1: -2.2364391598508835, 2: 0.6205947900678905, 3: 0.6623865847688559, 4: 1.562036259999875, 5: -0.13081282910947759, 6: 0.03914373833251773, 7: -0.995761652421108, 8: 1.0649494418154162, 9: 1.3744782478849122}, 'x2_2': {0: -0.5052556836786106, 1: 1.1464291788297152, 2: -0.5662380273138174, 3: 0.6875729143723538, 4: 0.04653136473130827, 5: -0.012885303852347407, 6: 1.5893672346098884, 7: 0.5464286050059511, 8: -0.10430829457707284, 9: -0.5441755265313813}, 'x2_3': {0: -0.9762973303149007, 1: -0.983731467806563, 2: 1.465827578266328, 3: 0.5325950414202745, 4: -1.4452121324204903, 5: 0.8148816373643869, 6: 0.470791989780882, 7: -0.17951636294180473, 8: 0.7351814781280054, 9: -0.28776723200679066}, 'x3_1': {0: 0.12751822396637064, 1: -0.21926633684030983, 2: 0.15758799357206943, 3: 0.5885412224632464, 4: 0.11916562911189271, 5: -1.6436210334529249, 6: -0.12444368631987467, 7: 1.4618564171802453, 8: 0.6847234328916137, 9: -0.23177118858569187}, 'x3_2': {0: -0.6452955690715819, 1: 1.052094761527654, 2: 0.20190339195326157, 3: 0.6839430295237913, 4: -0.2607691613858866, 5: 0.3315513026670213, 6: 0.015901139336566113, 7: 0.15243420084881903, 8: -0.7604225072161022, 9: -0.4387652927008854}, 'x3_3': {0: -1.067058994377549, 1: 0.8026914180717286, 2: -1.9868531745912268, 3: -0.5057770735303253, 4: -1.6589569342151713, 5: 0.358172252880764, 6: 1.9238983803281329, 7: 2.2518318810978246, 8: -1.2781475121874357, 9: -0.7103081175166167}})
Y = pd.DataFrame.from_dict({'CHOICE': {0: 1.0, 1: 1.0, 2: 2.0, 3: 2.0, 4: 3.0, 5: 2.0, 6: 1.0, 7: 1.0, 8: 2.0, 9: 2.0}})
Z = pd.DataFrame.from_dict({'z1': {0: 2.4196730570917233, 1: 2.4196730570917233, 2: 2.822802255159467, 3: 2.822802255159467, 4: 2.073171091633643, 5: 2.073171091633643, 6: 2.044165101485163, 7: 2.044165101485163, 8: 2.4001241292606275, 9: 2.4001241292606275}, 'z2': {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0, 5: 1.0, 6: 1.0, 7: 1.0, 8: 0.0, 9: 0.0}, 'z3': {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 2.0, 5: 2.0, 6: 2.0, 7: 2.0, 8: 3.0, 9: 3.0}})
id = pd.DataFrame.from_dict({'id_choice': {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0, 4: 5.0, 5: 6.0, 6: 7.0, 7: 8.0, 8: 9.0, 9: 10.0}, 'id_ind': {0: 1.0, 1: 1.0, 2: 2.0, 3: 2.0, 4: 3.0, 5: 3.0, 6: 4.0, 7: 4.0, 8: 5.0, 9: 5.0}} )
# Create a dataframe with all the data
data = pd.concat([id,X, Z, Y], axis=1)
```

Define the `torch.utils.data.Dataset()`

```python
# class to create a dataset for choice data
class ChoiceDataset_all(Dataset):
    '''
    Dataset for choice data.
    Args:
        data (pandas dataframe): dataframe with all the data
        args: namespace with the number of alternatives (J)
        id_variable (str): column that identifies the individual
    Returns:
        dictionary with the data for each individual
    '''
    def __init__(self, data, args, id_variable: str = "id_ind"):
        if id_variable not in data.columns:
            raise ValueError(f"Variable {id_variable} not in dataframe")
        self.data = data
        # select cluster variable
        self.cluster_ids = self.data[id_variable].unique()
        self.Y = torch.LongTensor(self.data['CHOICE'].values - 1).reshape(len(self.data['CHOICE'].index), 1)
        self.id = torch.LongTensor(self.data[id_variable].values).reshape(len(self.data[id_variable].index), 1)
        # number of individuals (N_n)
        self.N_n = torch.unique(self.id).shape[0]
        # number of choices made per individual (t_n)
        _, self.t_n = self.id.unique(return_counts=True)
        # total number of observations (N_t = total number of choices)
        self.N_t = self.t_n.sum(axis=0).item()
        # Select regressors: variables that start with "x"
        self.X_wide = data.filter(regex='^x')
        # turn X_wide into a tensor
        self.X = torch.DoubleTensor(self.X_wide.values)
        # number of alternatives (J) and number of regressors (K)
        self.J = args.J
        self.K = int(self.X_wide.shape[1] / self.J)
        # Select variables that start with "z"
        self.Z = torch.DoubleTensor(self.data.filter(regex='^z').values)

    def __len__(self):
        return self.N_n  # number of individuals

    def __getitem__(self, idx):
        # select the row indices belonging to individual idx+1
        self.index = torch.where(self.id == idx + 1)[0]
        self.len_batch = self.index.shape[0]
        # Select observations for the individual
        Y_batch = self.Y[self.index]
        Z_batch = self.Z[self.index]
        id_batch = self.id[self.index]
        X_batch = self.X[self.index]
        # reshape X_batch to have the right dimensions (observations, K, J)
        X_batch = X_batch.reshape(self.len_batch, self.K, self.J)
        print("\n")
        print("Selection of the data for by __getitem__ for individual", idx + 1)
        print(X_batch.shape, "X_batch when selecting for ind", idx + 1)
        print(Z_batch.shape, "Z_batch when selecting for ind", idx + 1)
        print(Y_batch.shape, "Y_batch when selecting for ind", idx + 1)
        # print(id_batch.shape, "id_batch when selecting for ind", idx + 1)
        return {'X': X_batch, 'Z': Z_batch, 'Y': Y_batch, 'id': id_batch}
```
Loop over the `torch.utils.data.DataLoader()`

```python
choice_data = ChoiceDataset_all(data, args, id_variable="id_ind")
data_loader = DataLoader(choice_data, batch_size=2, shuffle=False, num_workers=0, drop_last=False)
for idx, data_dict in enumerate(data_loader):
print("\n")
print("Data of the Batch # ", idx+1, "inside the enumerate")
print("shape X (outside foo)", data_dict['X'].shape)
print("shape Z (outside foo)", data_dict['Z'].shape)
print("shape Y (outside foo)", data_dict['Y'].shape)
# print("shape id (outside foo)", data_dict['id'])
1 Answer
Thanks to @jorad, I found out the beauty of `collate_fn`. Here is one way to solve the problem. Edit: another suggestion was made on the official PyTorch forum, and I profiled both alternatives below. TL;DR: a simple `view` inside the loop is faster than implementing my custom `collate_fn`. The detailed answer follows:
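A minimal sketch of what such a custom `collate_fn` can look like (the helper name `my_collate` is illustrative, not from the original code); it concatenates each field of the sample dictionaries along dim 0, reusing `choice_data` from above:

```python
def my_collate(batch):
    # batch is the list of dictionaries returned by __getitem__;
    # concatenate every field along the observation dimension
    # instead of stacking on a new leading dimension.
    return {key: torch.cat([sample[key] for sample in batch], dim=0)
            for key in batch[0]}

data_loader = DataLoader(choice_data, batch_size=2, shuffle=False,
                         num_workers=0, drop_last=False, collate_fn=my_collate)

for idx, data_dict in enumerate(data_loader):
    print("shape X", data_dict['X'].shape)  # torch.Size([4, 3, 3]) for batches 1-2
    print("shape Z", data_dict['Z'].shape)  # torch.Size([4, 3])
    print("shape Y", data_dict['Y'].shape)  # torch.Size([4, 1])
```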
Profiling
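For comparison, a sketch of the `view`-inside-the-loop alternative suggested on the forum, assuming the default collation (the exact timing code is not reproduced here; target shapes follow the output printed above):

```python
# Keep the DEFAULT collate_fn, which stacks samples and adds a leading
# dim (e.g. X: [2, 2, 3, 3]), then flatten that dim inside the loop.
data_loader = DataLoader(choice_data, batch_size=2, shuffle=False,
                         num_workers=0, drop_last=False)

for idx, data_dict in enumerate(data_loader):
    # Merge the stacked-individuals dim into the observation dim.
    X = data_dict['X'].view(-1, choice_data.K, args.J)     # [4, 3, 3]
    Z = data_dict['Z'].view(-1, data_dict['Z'].shape[-1])  # [4, 3]
    Y = data_dict['Y'].view(-1, 1)                         # [4, 1]
```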