tensorflow 如何加快tf.数据集的创建和训练?

u5rb5r59  于 2022-11-16  发布在  其他
关注(0)|答案(1)|浏览(146)

我正在尝试创建一个用于训练目的的标签数据集。在这个数据集中,我想用一个描述点位置的特定数字来标记视频帧。
由于视频总长度约为1小时每秒25帧(fps),因此总帧数超过100,000,仅创建数据集就需要一天多的时间。
我使用tf.keras.image_dataset_from_directory创建数据集,方法是直接阅读文件,然后通过我创建的 *numpy数组 * 为其分配一个标签。我想问一下,是否有更快的方法来创建数据集或以并行方式合并tf.data. dataset。
下面是我的代码。(我使用Jupyter笔记本上的VSCode和Python 3.9.7和tf.版本=2.0):

import numpy as np
import tensorflow as tf

def print_dataset(data_set:tf.data.Dataset):
        iterator = data_set.as_numpy_iterator()
        # create a numpy array labelled
        images = np.empty((0, 108, 192, 1))
        labels = np.empty((0))
        for element in iterator:
                images = np.append(images, element[0], axis=0)
                labels = np.append(labels, element[1], axis=0)
              
        print(np.shape(images))
        print(np.shape(labels))
        plt.imshow(images[0])
        print(labels[0])
        return images, labels

# creating the frame dataset from directory
image = tf.keras.utils.image_dataset_from_directory(
    'frames', labels=labels, label_mode='int', image_size=(108,192), color_mode='grayscale',
    batch_size=1) # frames directory contains the frames used for training.

print(image)
plt.show(image)

#generating the labelled dataset and getting the img as well as labels
img, labels= print_dataset(image)
)
6ie5vjzr

6ie5vjzr1#

有两种方法可以实现这一点,第一种方法是通过生成器实现,但该过程成本较高,还有另一种方法称为使用tf.data进行更精细的控制。您可以在此链接查看此方法
https://www.tensorflow.org/tutorials/load_data/images
但是,我将向您展示一个简短的演示,说明如何更快地加载图像......所以,让我们开始......

#First import some libraries which are needed
import os
import tensorflow as tf
import matplotlib.pyplot as plt

我只上两门“猫”和“狗”的课。你可以上两门以上的课...

batch_size = 32
img_height = 180
img_width = 180

#define your data directory where your dataset is placed

data_dir = path to your dataset folder

#Now, here define a list of names for your dataset, like I am only loading cats and dogs... you can fill it with more if you have more

#Now, glob the list of images in these two directories (cats & Dogs)
list_files = tf.data.Dataset.list_files(data_dir + '/*/*.jpg', shuffle=None)

image_count = len(list_files)

#Now, define your class names to label your dataset later...
class_names = ['cats', 'dogs']

#Now, here define the validation, test, train, etc.

val_size = int(image_count * 0.2)
train_ds = list_files.skip(val_size)
val_ds = list_files.take(val_size)

#To get labels
def get_label(file_path):
  # Convert the path to a list of path components
  parts = tf.strings.split(file_path, os.path.sep)
  one_hot = parts[1] == class_names
  # Integer encode the label
  return tf.argmax(one_hot)

def decode_img(img):
  # Convert the compressed string to a 3D uint8 tensor
  img = tf.io.decode_jpeg(img, channels=3)
  # Resize the image to the desired size
  return tf.image.resize(img, [img_height, img_width])

def process_path(file_path):
  label = get_label(file_path)
  # Load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label

#Use Dataset.map to create a dataset of image, label pairs:
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_ds = train_ds.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=tf.data.AUTOTUNE)

#Configure the dataset for performance, increase the buffer-size if you have a lot of data...
def configure_for_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=1000)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(buffer_size=tf.data.AUTOTUNE)
  return ds

train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)

#Visualize the data
image_batch, label_batch = next(iter(train_ds))

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].numpy().astype("uint8"))
  label = label_batch[i]
  plt.title(class_names[label])
  plt.axis("off")

输出:

COLAB文件的链接为:
https://colab.research.google.com/drive/1oUNuGVDWDLqwt_YQ80X-CBRL6kJ_YhUX?usp=sharing

相关问题