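I have a very complex transformer model, and I need to compute MRR (mean reciprocal rank) from scratch.
- (I think there is a problem with dataset batching (1248 = 39 * 32), but I don't have enough expertise to solve it.)
So here is the code I wrote (the problem line is set apart):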
def _count_mrr(self, y_true: tf.Tensor, y_pred: tf.Tensor):
    y_true = tf.reshape(y_true, shape=(1, self.max_length - 1))
    y_pred = tf.reshape(y_pred, shape=(1, self.max_length - 1, self._data_controller._vocab_size))
    y_true = tf.cast(y_true, dtype=tf.dtypes.int32)
    y_true = tf.squeeze(y_true)
    y_pred = tf.squeeze(y_pred)

    y_true = tf.reshape(y_true, shape=(self.max_length - 1, 1))

    y_pred = tf.math.top_k(y_pred, self._data_controller._vocab_size).indices
    where_tensor = tf.equal(y_pred, y_true)
    where_tensor = tf.where(where_tensor)[:, 1]
    where_tensor = tf.cast(tf.add(where_tensor, 1), dtype=tf.dtypes.float64)
    where_tensor = tf.divide(tf.constant(np.ones(self.max_length - 1)), where_tensor)
    return tf.math.reduce_mean(where_tensor)
When I try to run this code, it fails with this error:
File "d:\Development\transformer_chatbot\chatbot\transformer.py", line 211, in _count_mrr
y_true = tf.reshape(y_true, shape=(self.max_length - 1, 1))
Node: 'Reshape_4'
Input to reshape is a tensor with 1248 values, but the requested shape has 39
[[{{node Reshape_4}}]] [Op:__inference_train_function_39181]
But if I instead try to run y_true = tf.reshape(y_true, shape=(1248, 1)), I get this:
File "d:\Development\transformer_chatbot\chatbot\transformer.py", line 211, in _count_mrr *
y_true = tf.reshape(y_true, shape=(1248, 1))
ValueError: Cannot reshape a tensor with 39 elements to shape [1248,1] (1248 elements) for
'{{node Reshape_4}} = Reshape[T=DT_INT32, Tshape=DT_INT32](Squeeze, Reshape_4/shape)' with
input shapes: [39], [2] and with input tensors computed as partial shapes: input[1] = [1248,1].
1 Answer
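As I mentioned earlier, the problem comes from tf.Dataset batching. The data is loaded with 3 dimensions, such as (None, maxlen-1, vocab_size), where None is the hidden batch size. At first, tf makes 2 empty trial passes with a batch size of 1 to check that everything works, but the actual batch size is 32. That is why neither hard-coded shape works: 39 elements show up during the check and 1248 (39 * 32) during real training. I solved this with a simpler combination of tf.reshape() calls: I flattened all of my batches into a single unit and then worked with that.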
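The answer's own snippet is not shown here, so the following is only a minimal sketch of the flattening idea, not the author's code. It assumes y_true has shape (batch, seq_len) and y_pred has shape (batch, seq_len, vocab_size); the names flatten_batch and mrr_flat are hypothetical. Passing -1 to tf.reshape lets TensorFlow infer batch * seq_len at run time, so the same code works for any batch size.

```python
import tensorflow as tf


def flatten_batch(y_true, y_pred, vocab_size):
    """Merge the batch and sequence axes into a single axis of rows."""
    # -1 is inferred at run time, so no batch size is hard-coded.
    y_true = tf.reshape(tf.cast(y_true, tf.int32), shape=(-1, 1))
    y_pred = tf.reshape(y_pred, shape=(-1, vocab_size))
    return y_true, y_pred


def mrr_flat(y_true, y_pred, vocab_size):
    """Mean reciprocal rank over every token position in the batch."""
    y_true, y_pred = flatten_batch(y_true, y_pred, vocab_size)
    # Token ids sorted by descending score, shape (rows, vocab_size).
    ranked = tf.math.top_k(y_pred, k=vocab_size).indices
    # 0-based position of the true token inside each row's ranking.
    hits = tf.cast(tf.equal(ranked, y_true), tf.int32)
    positions = tf.argmax(hits, axis=-1)
    ranks = tf.cast(positions, tf.float32) + 1.0
    return tf.reduce_mean(1.0 / ranks)
```

Inside the model this would be called as something like mrr_flat(y_true, y_pred, self._data_controller._vocab_size), replacing the fixed (self.max_length - 1, 1) reshape.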
Note that the final fix is not perfect: it needs to allocate a lot of memory for the tensors.
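If memory becomes a concern, one possible alternative (a different technique from the fix described above, offered only as a sketch) is to skip the full top_k sort and compute each rank by counting how many vocabulary scores beat the true token's score. The function name mrr_by_counting is hypothetical, and the same input shapes are assumed: y_true as (batch, seq_len) and y_pred as (batch, seq_len, vocab_size). It still builds one boolean mask of size (rows, vocab_size), but avoids the sorted index tensor.

```python
import tensorflow as tf


def mrr_by_counting(y_true, y_pred):
    """Mean reciprocal rank without sorting the whole vocabulary."""
    vocab_size = tf.shape(y_pred)[-1]
    labels = tf.reshape(tf.cast(y_true, tf.int32), shape=(-1,))
    scores = tf.reshape(y_pred, shape=(-1, vocab_size))
    # Score assigned to the true token in each row, shape (rows,).
    true_scores = tf.gather(scores, labels, batch_dims=1)
    # Rank = 1 + number of tokens scored strictly higher than the true one.
    # Tied scores therefore get the most optimistic rank.
    higher = tf.math.count_nonzero(scores > true_scores[:, tf.newaxis], axis=-1)
    ranks = tf.cast(higher, tf.float32) + 1.0
    return tf.reduce_mean(1.0 / ranks)
```

When no two scores in a row are tied, this gives the same value as the sort-based computation.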