Pytorch问题:当num_workers > 0时,我的jupyter卡住了

xa9qqrwz  于 2023-11-19  发布在  其他
关注(0)|答案(3)|浏览(201)

这是我在PyTorch中的代码片段,当我使用num_workers > 0时,我的jupiter notebook卡住了,我花了很多时间在这个问题上没有任何答案。我没有GPU,我只使用CPU工作。

class IndexedDataset(Dataset):

def __init__(self,data,targets, test=False):
    self.dataset = data 
    if not test:
        self.labels = targets.numpy()
        self.mask =  np.concatenate((np.zeros(NUM_LABELED), np.ones(NUM_UNLABELED)))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]
        return image, self.labels[idx]
    
    def display(self, idx):
        plt.imshow(self.dataset[idx], cmap='gray')
        plt.show()

train_set = IndexedDataset(train_data, train_target, test = False)

test_set = IndexedDataset(test_data, test_target, test = True)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=2)

test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, num_workers=2)

字符串
任何帮助,感激不尽。

hi3rlvi2

hi3rlvi21#

num_workers大于0时,PyTorch使用多个进程进行数据加载。
多字节笔记本在多处理方面存在已知问题。
解决这个问题的一种方法是不使用MyMyMyter Notebook-只需编写一个普通的.py文件并通过命令行运行它。
或者尝试使用这里建议的:Jupyter notebook never finishes processing using multiprocessing (Python 3)

iaqfqrcu

iaqfqrcu2#

由于jupyter Notebook不支持python多处理,因此有两个瘦库,您应该安装其中一个,如本文所述12
我更喜欢用两种方法来解决我的问题,而不使用任何外部库:
1.通过将我的文件从.ipynb格式转换为.py格式,并在终端中运行它,我在main()函数中编写代码,如下所示:

...
...

train_set = IndexedDataset(train_data, train_target, test = False)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, num_workers=4)

 if `__name__ ==  '__main__'`:
     for images,label in train_loader:
         print(images.shape)

字符串
1.使用多处理库如下:
try.ipynb中:

import multiprocessing as mp
import processing as ps

...
...

train_set = IndexedDataset(train_data, train_target, test = False)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE)
    
if __name__=="__main__":
    p = mp.Pool(8)
    r = p.map(ps.getShape,train_loader) 
    print(r)
    p.close()


processing.py文件中:

def getShape(data):
    for i in data:
        return i[0].shape

y1aodyip

y1aodyip3#

至少在我的例子中,将创建DataLoader的代码放在一个单独的Python文件中解决了这个问题。一个简单的Wrapper就足够了。
例如,在dataloader_wrapper.py中创建DataLoader:

class DataloaderWrapper:
    def __init__(self, dataset, batch_size, num_workers, shuffle):
        self.dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

字符串
然后导入data loader_wrapper,通过Wrapper创建并调用data loader,如下所示:

from dataloader_wrapper import DataloaderWrapper
train_loader = DataloaderWrapper(dataset=dataset, batch_size=50000, shuffle=True, num_workers=7)

for data in train_loader.dataloader:
    pass


这是一个有点变通,但为我工作。

相关问题