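I'm trying to use AlexNet to classify spectrogram images generated from 3 s audio clips. I have trained the model on my dataset successfully and am now trying to identify the images that the model misclassifies.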
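I can get the filenames of the images by calling iterator.dataset.data_df.adressfname, but I'm not sure how to use this statement inside the for loop of the get_predictions function. If I try to retrieve the filename with **iterator.dataset.data_df.adressfname[i]**, I get the following error: expected Tensor as element 0 in argument 0, but got str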
Ultimately, I'd like to create a DataFrame containing the filename, the actual label and the predicted label. Does anyone have any suggestions?
import glob

import cv2
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class CustomDataset(Dataset):
    def __init__(self, img_path, csv_file, transforms):
        self.imgs_path = img_path
        self.csv_train_file = csv_file
        file_list = glob.glob(self.imgs_path + "*")
        self.data = []
        self.data_df = pd.read_csv(self.csv_train_file)
        self.transforms = transforms
        # One [image path, class name] entry per CSV row, in row order.
        for ind in self.data_df.index:
            img_path = self.data_df['spectrogramSegFilename'][ind]
            class_name = self.data_df['dx'][ind]
            self.data.append([img_path, class_name])
        self.class_map = {"ProbableAD": 0, "Control": 1}
        self.img_dim = (256, 256)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, class_name = self.data[idx]
        img = cv2.imread(img_path)
        img = cv2.resize(img, self.img_dim)
        class_id = self.class_map[class_name]
        # HWC uint8 array -> CHW tensor, as expected by ToPILImage.
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.permute(2, 0, 1)
        data = self.transforms(img_tensor)
        class_id = torch.tensor([class_id])
        return data, class_id
if __name__ == "__main__":
    transformations = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize((0.49966475, 0.1840554, 0.34930056),
                             (0.35317238, 0.17343724, 0.1894943)),
    ])
    train_dataset = CustomDataset("/spectrogram_images/spectrogram_train/", "train_features_segmented.csv", transformations)
    test_dataset = CustomDataset("/spectrogram_images/spectrogram_test/", "test_features_segmented.csv", transformations)
    train_data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_data_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
def get_predictions(model, iterator, device):
    model.eval()
    images = []
    labels = []
    probs = []
    participant_ids = []
    with torch.no_grad():
        for i, (x, y) in enumerate(iterator):
            x = x.to(device)
            y_pred = model(x)
            y_prob = F.softmax(y_pred, dim=-1)
            participant_ids.append(iterator.dataset.data_df.adressfname[i])
            images.append(x.cpu())
            labels.append(y.cpu())
            probs.append(y_prob.cpu())
    images = torch.cat(images, dim=0)
    labels = torch.cat(labels, dim=0)
    probs = torch.cat(probs, dim=0)
    participant_ids = torch.cat(participant_ids, dim=0)
    return images, labels, probs, participant_ids
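For the stated end goal (a DataFrame of filename, actual label and predicted label), here is a minimal sketch of the last step. It assumes the filenames have already been collected as a plain Python list in the same order as labels and probs. Note that torch.cat only concatenates tensors, so calling it on a list of str is what produces the quoted "expected Tensor as element 0 ... but got str" error. predictions_to_dataframe and IDX_TO_CLASS are hypothetical names; IDX_TO_CLASS simply inverts the question's class_map.

```python
import pandas as pd
import torch

# Hypothetical inverse of CustomDataset.class_map ({"ProbableAD": 0, "Control": 1}).
IDX_TO_CLASS = {0: "ProbableAD", 1: "Control"}


def predictions_to_dataframe(filenames, labels, probs):
    """One row per sample: filename, actual label and predicted label.

    filenames: sequence of str, in the same order as labels and probs
    labels:    integer class ids, shape [N] or [N, 1] (as from torch.cat of y)
    probs:     softmax outputs, shape [N, num_classes]
    """
    true_ids = labels.reshape(-1).tolist()   # flatten [N, 1] -> [N]
    pred_ids = probs.argmax(dim=1).tolist()  # most probable class per row
    return pd.DataFrame({
        "filename": list(filenames),
        "actual": [IDX_TO_CLASS[i] for i in true_ids],
        "predicted": [IDX_TO_CLASS[i] for i in pred_ids],
    })


if __name__ == "__main__":
    labels = torch.tensor([[0], [1], [1]])
    probs = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
    df = predictions_to_dataframe(["a.png", "b.png", "c.png"], labels, probs)
    # Keep only the rows the model got wrong.
    print(df[df["actual"] != df["predicted"]])
```

Filtering on actual != predicted then gives exactly the misclassified files.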
1 Answer
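Just add a third field (for example, the image path) to the return value of `__getitem__()`, and then unpack that field when you iterate over the data loader.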
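A minimal sketch of that approach, assuming the third field is the image path. ToyDataset is a hypothetical stand-in for CustomDataset (random tensors instead of cv2-loaded spectrograms), and the small Sequential model is a placeholder for AlexNet; only the extra return value and the three-way unpacking matter here. Because the default collate function leaves a batch of strings as a sequence of str rather than a tensor, the filenames are gathered with list.extend instead of torch.cat.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for CustomDataset: random images, fake filenames."""

    def __init__(self, n):
        # Each entry is [img_path, class_id], mirroring CustomDataset.data.
        self.data = [[f"clip_{i}.png", i % 2] for i in range(n)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, class_id = self.data[idx]
        img = torch.rand(3, 8, 8)  # placeholder for the loaded spectrogram
        # Third field: the filename travels with its sample through the loader.
        return img, torch.tensor([class_id]), img_path


def get_predictions(model, iterator, device):
    model.eval()
    images, labels, probs, filenames = [], [], [], []
    with torch.no_grad():
        for x, y, fnames in iterator:
            y_prob = F.softmax(model(x.to(device)), dim=-1)
            images.append(x.cpu())
            labels.append(y.cpu())
            probs.append(y_prob.cpu())
            # Strings are not collated into a tensor, so extend a list
            # instead of calling torch.cat on them.
            filenames.extend(fnames)
    images = torch.cat(images, dim=0)
    labels = torch.cat(labels, dim=0)
    probs = torch.cat(probs, dim=0)
    return images, labels, probs, filenames


if __name__ == "__main__":
    loader = DataLoader(ToyDataset(10), batch_size=4, shuffle=True)
    # Hypothetical tiny model standing in for AlexNet.
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 2))
    images, labels, probs, filenames = get_predictions(model, loader, "cpu")
    print(len(filenames), labels.shape, probs.shape)
```

Since each filename is collated into the same batch as its image and label, the order of filenames matches labels and probs even with shuffle=True.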
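Note that in your current code, `i` indexes batches, and it has nothing to do with the `dataset` object wrapped in the `dataloader`: the data loader has a batch size, so it collates several elements of the dataset under a single index, and the `dataloader` also shuffles the elements of the dataset. You therefore should not use this index (`i`) to refer to the dataset object. If you want to refer to the underlying data items in the `dataset` object, simply return the index `idx` from `__getitem__()`. You can then use this `idx` to refer back to the original `dataset` data and get the filename that way.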
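A minimal sketch of the idx variant. IndexedDataset and collect_names are hypothetical names, and the label column is an assumption made for the toy data. In the question's CustomDataset, self.data is filled by iterating over data_df.index in order, so a sample's idx is also its positional row in data_df; that is why the lookup uses .iloc. The default collate function turns a batch of integer idx values into a tensor, hence idx.tolist().

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class IndexedDataset(Dataset):
    """Hypothetical stand-in for CustomDataset that also returns idx."""

    def __init__(self, data_df):
        # Positional row i of data_df corresponds to sample idx == i.
        self.data_df = data_df.reset_index(drop=True)

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        x = torch.rand(3, 8, 8)  # placeholder for the loaded spectrogram
        y = torch.tensor([int(self.data_df["label"].iloc[idx])])
        return x, y, idx  # a batch of idx values is collated into a tensor


def collect_names(loader, column="adressfname"):
    """Return per-sample names and labels in the order the loader yields them."""
    names, labels = [], []
    for _, y, idx in loader:
        # idx holds dataset positions, unlike the batch counter from enumerate().
        rows = loader.dataset.data_df.iloc[idx.tolist()]
        names.extend(rows[column].tolist())
        labels.append(y)
    return names, torch.cat(labels, dim=0)


if __name__ == "__main__":
    df = pd.DataFrame({"adressfname": ["p0", "p1", "p2", "p3", "p4"],
                       "label": [0, 1, 0, 1, 0]})
    loader = DataLoader(IndexedDataset(df), batch_size=2, shuffle=True)
    names, labels = collect_names(loader)
    print(list(zip(names, labels.view(-1).tolist())))
```

With the real dataset, the same lookup inside get_predictions would be iterator.dataset.data_df.iloc[idx.tolist()], or iterator.dataset.data[k][0] for the image path of sample k.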