如何获取测试数据集中错误分类图像的文件名- pytorch?

jhiyze9q  于 2023-02-04  发布在  其他
关注(0)|答案(1)|浏览(220)

我正在尝试使用AlexNet对为3S音频片段生成的声谱图图像进行分类。我已经成功地训练了我的数据集,并正在尝试识别模型错误分类的图像。
我可以通过调用www.example.com _df.adressfname来获取图像的文件名iterator.dataset.data,但是我不确定如何在get_predictions函数的for循环中使用此语句。如果我尝试使用**iterator.dataset.data_df.adressfname[i]**检索文件名,则会出现以下错误:参数0中的元素0应为Tensor,但得到的是str
最后,我想创建一个包含文件名、实际标签和预测标签的 Dataframe 。有人有什么建议吗?

class CustomDataset(Dataset):
  def __init__(self, img_path, csv_file, transforms):
    self.imgs_path = img_path
    self.csv_train_file = csv_file
    file_list = glob.glob(self.imgs_path + "*")
    self.data = []
    self.data_df = pd.read_csv(self.csv_train_file)
    self.transforms = transforms

    for ind in self.data_df.index:
      img_path = self.data_df['spectrogramSegFilename'][ind]
      class_name = self.data_df['dx'][ind]
      self.data.append([img_path, class_name])
    self.class_map = {"ProbableAD" : 0, "Control": 1}
    self.img_dim = (256, 256)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    img_path, class_name = self.data[idx]
    img = cv2.imread(img_path)
    img = cv2.resize(img, self.img_dim)
    class_id = self.class_map[class_name]
    img_tensor = torch.from_numpy(img)
    img_tensor = img_tensor.permute(2, 0, 1)
    data = self.transforms(img_tensor)
    class_id = torch.tensor([class_id])
    return data, class_id

if __name__ == "__main__":
    transformations = transforms.Compose([transforms.ToPILImage(), transforms.Resize(256), transforms.CenterCrop(256), transforms.ToTensor(), transforms.Normalize((0.49966475, 0.1840554, 0.34930056), (0.35317238, 0.17343724, 0.1894943))])
    train_dataset = CustomDataset("/spectrogram_images/spectrogram_train/", "train_features_segmented.csv", transformations)
    test_dataset = CustomDataset("/spectrogram_images/spectrogram_test/", "test_features_segmented.csv", transformations)
    train_dataset = CustomDataset("spectrogram_images/spectrogram_train/", "/train_features_segmented.csv")
    test_dataset = CustomDataset("spectrogram_images/spectrogram_test/", "/test_features_segmented.csv")
    train_data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_data_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

def get_predictions(model, iterator, device):

    model.eval()

    images = []
    labels = []
    probs = []
    participant_ids = []

    with torch.no_grad():

      for i, (x, y) in enumerate(iterator): 

            x = x.to(device)
            y_pred = model(x)
            y_prob = F.softmax(y_pred, dim=-1)
            participant_ids.append(iterator.dataset.data_df.adressfname[i])
            images.append(x.cpu())
            labels.append(y.cpu())
            probs.append(y_prob.cpu())
            
    images = torch.cat(images, dim=0)
    labels = torch.cat(labels, dim=0)
    probs = torch.cat(probs, dim=0)
    participant_ids = torch.cat(participant_ids, dim=0)

    return images, labels, probs, participant_ids
oaxa6hgo

oaxa6hgo1#

只需向__getitem__()的返回签名添加第三个字段:

def __getitem__(self,idx): 
    ...
    return data,class_id,img_path

然后,当您调用数据加载器时:

for i, (x,y,img_paths) in enumerate(iterator):

    ... call model ...

    ... compare outputs to labels ...

    ... identify incorrect batch indices

    mislabeled_files = [img_paths[idx] for idx in incorrect_batch_indices]

注意,在当前代码块中,i索引批索引,这与 Package 在dataloader中的dataset对象没有任何关系(因为数据加载器具有批量大小从而它将来自数据集的多个元素整理成单个索引,并且因为dataloader会重排数据集中的元素),所以不应使用此索引(i)引用数据集对象。如果要引用dataset对象中的底层数据项,只需在__getitem__()中返回数据idx即可:

def __getitem__(self,idx): 
    ...
    return data,class_id,idx

然后可以使用这个idx引用原始的dataset数据,并通过这种方式获得文件名。

相关问题