PyTorch: why does iter(dataloader) hang forever instead of stopping? (AttributeError related to __main__)

Posted by 7vux5j2d on 2022-11-09 in Other

I am revising some machine learning material using a tutorial exercise I did before. It had no problems back then, but now dataiter = iter(train_loader) runs forever: it just hangs and never stops. It works fine on Google Colab, which runs Python 3.7.15, but locally I am on Python 3.8.11.
Here are the sizes of the full dataset and its splits:

Length of Dataset is 1470
Full: 1470
Train: 940
Valid: 236
Test: 294
import multiprocessing as mp
from torch.utils.data import DataLoader

bs = 32

# num_cpu = 2
num_cpu = mp.cpu_count()  # one worker process per CPU core

# train, valid and test are the dataset splits listed above
train_loader = DataLoader(train, batch_size=bs, shuffle=True, num_workers=num_cpu, pin_memory=True)
valid_loader = DataLoader(valid, batch_size=bs, shuffle=False, num_workers=num_cpu, pin_memory=True)
test_loader = DataLoader(test, batch_size=bs, shuffle=False, num_workers=num_cpu, pin_memory=True)
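Whether num_workers=num_cpu works in a notebook depends on how worker processes are started. A likely reason the same code behaves differently on Colab: the /Users/kelvin paths in the traceback suggest macOS, where Python 3.8 changed the default multiprocessing start method from 'fork' to 'spawn', while Colab runs on Linux, where Python 3.7 defaults to 'fork'. The sketch below only reports the start method on the current machine. The helper name describe_worker_start is hypothetical, and the sketch assumes DataLoader uses the default multiprocessing context (no multiprocessing_context argument is passed above).

```python
import multiprocessing as mp
import sys


def describe_worker_start() -> str:
    """Summarize how multiprocessing (and, by default, DataLoader workers
    when num_workers > 0) will start child processes in this interpreter."""
    method = mp.get_start_method()
    if method == "fork":
        detail = "children inherit a copy of the parent's memory, including classes defined interactively"
    else:
        detail = "children do not inherit the parent's interpreter state, so every class they unpickle must be importable"
    version = f"{sys.version_info.major}.{sys.version_info.minor}"
    return f"{sys.platform}, Python {version}: start method '{method}' ({detail})"


print(describe_worker_start())
```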

After that, I run dataiter = iter(train_loader) and get this error message:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/kelvin/opt/anaconda3/envs/torch-gpu/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/kelvin/opt/anaconda3/envs/torch-gpu/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'HrDataset' on <module '__main__' (built-in)>

Here is the HrDataset class, which I defined in the notebook:

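import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset

class HrDataset(Dataset):

    def __init__(self, file_path):
        print('HrDataset is loading {}'.format(file_path))
        df = pd.read_csv(file_path)
        self.df = self.preprocessing(df)
        print("Preprocessing is completed")
        print('Length of HrDataset is {}'.format(len(self.df)))

    def __getitem__(self, idx):
        # Column 0 is treated as the label; all remaining columns are features
        X = np.array(self.df.iloc[idx, 1:]).astype(np.float32)
        y = self.df.iloc[idx, 0]

        return X, y

    def __len__(self):
        # Was len(df), which only works if a global df happens to exist;
        # use the dataset's own DataFrame instead
        return len(self.df)

    def preprocessing(self, df):
        for col in df.columns:
            if df.dtypes[col] == 'object':
                # Categorical column: fill gaps, then one-hot encode (>2 categories) or label-encode (binary)
                df[col] = df[col].fillna('NA')
                df[col] = df[col].astype('category')
                if len(df[col].cat.categories) > 2:
                    df = pd.get_dummies(df, columns=[col])
                else:
                    df[col] = LabelEncoder().fit_transform(df[col])
            else:
                # Numeric column: fill gaps with 0
                df[col] = df[col].fillna(0)
        return df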
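The AttributeError in the traceback comes from pickle: with 'spawn', each worker unpickles the dataset it receives, and pickle stores a class by reference (module name plus class name), not by value. The sketch below uses a throwaway module named fake_notebook_main and a class HrDatasetStandIn, both hypothetical stand-ins for the notebook's __main__ and HrDataset. It shows the reference inside the pickled bytes, then the lookup failing once the module no longer contains the class, which is the situation a freshly spawned worker is in.

```python
import pickle
import sys
import types

# Hypothetical module playing the role of the notebook's __main__.
notebook = types.ModuleType("fake_notebook_main")
sys.modules[notebook.__name__] = notebook
exec(
    "class HrDatasetStandIn:\n"
    "    def __init__(self, rows):\n"
    "        self.rows = rows\n",
    notebook.__dict__,
)

payload = pickle.dumps(notebook.HrDatasetStandIn([1, 2, 3]))

# The bytes record where to find the class (module name + class name),
# not the class's code; the instance state {'rows': [1, 2, 3]} is stored too.
stores_reference = b"fake_notebook_main" in payload and b"HrDatasetStandIn" in payload

# Simulate the receiving side: same module name, but no class definition in it.
saved_class = notebook.HrDatasetStandIn
del notebook.HrDatasetStandIn
try:
    pickle.loads(payload)
    lookup_failed = False
except (AttributeError, pickle.UnpicklingError) as exc:
    lookup_failed = True
    print(type(exc).__name__, exc)
finally:
    notebook.HrDatasetStandIn = saved_class  # put the class back

# With the class available again, the same bytes load fine.
restored = pickle.loads(payload)
print(restored.rows)
```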
Answer 1 (by mzmfm0qo):

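I finally found the problem. I had defined HrDataset in an .ipynb file, and the worker processes created by Python multiprocessing apparently could not find an object defined there. I fixed it by moving the class into a file, HrDataset.py, and importing it with from HrDataset import HrDataset.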
Reference: Multiprocessing example giving AttributeError
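A minimal sketch of why this fix works, assuming the 'spawn' start method: a class imported from a real file can be found again by name in a new process, while a class defined only in the running __main__ cannot. Instead of a DataLoader, it starts a brand-new Python interpreter with subprocess to play the worker's role. The names squares_dataset.py, SquaresDataset and InlineDataset are hypothetical stand-ins for HrDataset.py and the notebook-defined HrDataset.

```python
import importlib
import pickle
import subprocess
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical module file, standing in for HrDataset.py.
MODULE_SRC = textwrap.dedent("""
    class SquaresDataset:
        def __init__(self, n):
            self.n = n

        def __len__(self):
            return self.n

        def __getitem__(self, idx):
            return idx * idx
""")

# Run by a brand-new interpreter, like a worker started with 'spawn':
# it only knows classes it can import, not what the parent defined in __main__.
WORKER_SRC = textwrap.dedent("""
    import pickle, sys
    sys.path.insert(0, sys.argv[1])
    try:
        dataset = pickle.loads(sys.stdin.buffer.read())
        print("ok", dataset[3])
    except (AttributeError, ImportError) as exc:
        print("failed", type(exc).__name__)
""")


class InlineDataset:
    """Defined in the running script itself, like a class in a notebook cell."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, idx):
        return idx * idx


def send_to_fresh_interpreter(obj, module_dir):
    """Pickle obj here and unpickle it in a separate Python process."""
    result = subprocess.run(
        [sys.executable, "-c", WORKER_SRC, module_dir],
        input=pickle.dumps(obj),
        capture_output=True,
        check=True,
    )
    return result.stdout.decode().strip()


with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "squares_dataset.py").write_text(MODULE_SRC)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    from squares_dataset import SquaresDataset

    inline_result = send_to_fresh_interpreter(InlineDataset(5), tmp)
    module_result = send_to_fresh_interpreter(SquaresDataset(5), tmp)
    sys.path.remove(tmp)

print("class defined in __main__:", inline_result)
print("class imported from a module:", module_result)
```

Run as a script, the first printed line should report a failed lookup (an AttributeError, like the one in the question) and the second ok 9. Separately, num_workers=0 loads batches in the main process and avoids sending the dataset to workers at all, which can be a simpler option while experimenting in a notebook.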
