我有一个医学图像的数据集(.dcm
),可以批量读入TensorFlow
,但是我面临的问题是这些图像的标签在一个.csv
中,.csv
文件包含两列-image_path
(图像的位置)和image_labels
(0表示否;我想知道如何批量读取标签到TensorFlow
数据集中。我使用以下代码批量加载图像:-
import tensorflow as tf
import tensorflow_io as tfio
def process_image(filename):
image_bytes = tf.io.read_file(filename)
image = tf.squeeze(
tfio.image.decode_dicom_image(image_bytes, on_error='strict', dtype=tf.uint16),
axis = 0
)
x = tfio.image.decode_dicom_data(image_bytes, tfio.image.dicom_tags.PhotometricInterpretation)
image = (image - tf.reduce_min(image))/(tf.reduce_max(image) - tf.reduce_min(image))
if(x == "MONOCHROME1"):
image = 1 - image
image = image*255
image = tf.cast(tf.image.resize(image, (512, 512)),tf.uint8)
return image
# train_images is a list containing the locations of .dcm images
dataset = tf.data.Dataset.from_tensor_slices(train_images)
dataset = dataset.map(process_image, num_parallel_calls=4).batch(50)
因此,我可以将图像加载到TensorFlow
数据集中,但我想知道如何批量加载图像标签。
1条答案
按热度按时间hi3rlvi21#
用下面这样的代码代替最后两行应该可以工作:
现在,
dataset
可以传递给网络的.fit()
、.predict()
和其他方法:或者,你可以创建第二个包含标签的数据集,然后用www.example.com()将两个数据集合并tf.data.Dataset.zip,它的工作原理类似于python的原生
zip
。我更喜欢第一种方法,因为我觉得它更干净一些+我可以,例如,打乱文件名/标签,然后才解析文件,而不是相反。