bounty将在4天后过期。回答此问题可获得+100声望奖励。Alberto A希望引起更多人关注此问题。
我正在Keras上为我拥有的一些数据实现一个简单的健全性检查模型。我的训练数据集由大约550个文件组成,每个文件贡献了大约150个样本。每个训练样本具有以下签名:
({'input_a': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None),
'input_b': TensorSpec(shape=(None, 900, 1), dtype=tf.float64, name=None)},
TensorSpec(shape=(None, 1), dtype=tf.int64, name=None)
)
本质上,每个训练样本由两个形状为(900,1)的输入组成,目标是单个(二进制)标签。我的模型的第一步是将输入连接成(900,2)Tensor。
训练样本的总数约为70000。
作为模型的输入,我创建了一个tf.data.Dataset,并应用了一些准备步骤:
tf.Dataset.filter
:过滤部分标签无效的样本tf.Dataset.shuffle
tf.Dataset.filter
:对训练数据集采样不足tf.Dataset.batch
步骤3是我的问题中最重要的。为了对数据集进行欠采样,我应用了一个简单的函数:
def undersampling(dataset: tf.data.Dataset, drop_proba: Iterable[float]) -> tf.data.Dataset:
def undersample_function(x, y):
drop_prob_ = tf.constant(drop_proba)
idx = y[0]
p = drop_prob_[idx]
v = tf.random.uniform(shape=(), dtype=tf.float32)
return tf.math.greater_equal(v, p)
return dataset.filter(undersample_function)
本质上,该函数接受一个概率向量drop_prob
,使得drop_prob[l]
是丢弃标签为l
的样本的概率(该函数有点复杂,但这是我发现将其实现为Dataset.filter
的方式)。使用相等的概率,比如drop_prob=[0.9, 0.9]
,我将丢弃大约90%的样本。
现在,问题是,我一直在为我的数据集尝试不同的欠采样,以便在性能和训练时间之间找到一个最佳点,但是当我欠采样时,epoch持续时间是相同的,而时间/步长增加。
保持我的batch_size
固定在20000,对于完整的数据集,我总共有4个批次,平均时间如下:
Epoch 4/1000
1/4 [======>.......................] - ETA: 9s
2/4 [==============>...............] - ETA: 5s
3/4 [=====================>........] - ETA: 2s
4/4 [==============================] - ETA: 0s
4/4 [==============================] - 21s 6s/step
而如果我用drop_prob = [0.9, 0.9]
对数据集进行欠采样(也就是说,我删除了大约90%的数据集),并保持相同的batch_size
为20000,我有1个批次,平均时间如下:
Epoch 4/1000
1/1 [==============================] - ETA: 0s
1/1 [==============================] - 22s 22s/step
请注意,虽然批次数量只有1,但纪元时间是相同的!只是处理批次需要更长的时间。
现在,作为一个健全性检查,我尝试了一种不同的欠采样方法,通过过滤文件来代替。所以我选择了大约55个训练文件(10%),在一个批次中具有类似数量的样本,并从tf.Dataset
中删除了欠采样。epoch时间按预期递减:
Epoch 4/1000
1/1 [==============================] - ETA: 0s
1/1 [==============================] - 2s 2s/step
注意,原始数据集有70014个训练样本,而通过tf.Dataset.filter的欠采样数据集有6995个样本,通过文件过滤的欠采样数据集有7018个样本,因此数量是一致的。
快得多。事实上,它所花费的时间大约是整个数据集所花费时间的10%。因此,在创建tf.Dataset
时,我执行欠采样的方式(通过使用tf.data.Dataset.filter
)存在问题,我想寻求帮助以找出问题所在。谢谢。
1条答案
按热度按时间js81xvg61#
似乎大部分时间都花在了数据集操作上,而不是网络本身。从研究证据来看,我的理论是,如果这是在GPU上执行的(数据集操作是在CPU上执行的),那么GPU必须在批处理之间等待数据集。因此,由于数据集操作总是花费相同的时间,这就是为什么在进度条上看起来批处理需要更长的时间。
如果在GPU上执行,Assert这个理论是否正确的正确方法是观察GPU利用率(您可以在
watch -n 0.5 nvidia-smi
运行时使用它,或者更好地使用nvtop
或任何其他GPU监控工具)。(不是内存!而是利用率)不接近100%,那么这将是一个指标,这确实是问题。注意,它不应该从90%下降,甚至不是半秒。要解决这个问题,您应该使用
Dataset.prefetch
作为代码中的最后一个数据集操作,这将导致CPU过度获取批处理,这样它就有批处理可供网络使用,从而不会等待。