TensorFlow cannot reshape a tensor using the tensor's actual runtime dimensions

Asked by e5njpo68 on 2023-03-09 (category: Other)

I have a fairly complex Transformer model, and I need to compute MRR (mean reciprocal rank) for it from scratch.

  • (I think the problem is with the dataset batching (1248 = 39 * 32), but I don't have enough expertise to figure it out.)
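
For context, MRR here is just the mean of 1 / rank, where rank is the position of the true token when the whole vocabulary is sorted by predicted score. A tiny worked example with made-up scores (vocab_size = 4, a single target position):

import tensorflow as tf

y_true = tf.constant([[2]], dtype=tf.int32)          # the true token id
y_pred = tf.constant([[0.9, 0.5, 0.3, 0.1]])         # scores for the 4 vocabulary entries
ranked = tf.math.top_k(y_pred, k=4).indices          # [[0, 1, 2, 3]]: token 2 is ranked 3rd
rank = tf.where(tf.equal(y_true, ranked))[:, 1] + 1  # [3]
print((1.0 / tf.cast(rank, tf.float64)).numpy())     # [0.33333333] -> reciprocal rank of 1/3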

So this is the code I wrote (the problematic line is set apart by blank lines):

def _count_mrr(self, y_true: tf.Tensor, y_pred: tf.Tensor):
    y_true = tf.reshape(y_true, shape=(1, self.max_length - 1))
    y_pred = tf.reshape(y_pred, shape=(1, self.max_length - 1, self._data_controller._vocab_size))
    y_true = tf.cast(y_true, dtype=tf.dtypes.int32)
    y_true = tf.squeeze(y_true)
    y_pred = tf.squeeze(y_pred)

    y_true = tf.reshape(y_true, shape=(self.max_length - 1, 1))  # <-- the problematic line

    # rank every vocabulary entry, locate the true token, then average 1 / rank
    y_pred = tf.math.top_k(y_pred, self._data_controller._vocab_size).indices
    where_tensor = tf.equal(y_pred, y_true)
    where_tensor = tf.where(where_tensor)[:, 1]
    where_tensor = tf.cast(tf.add(where_tensor, 1), dtype=tf.dtypes.float64)
    where_tensor = tf.divide(tf.constant(np.ones(self.max_length - 1)), where_tensor)
    return tf.math.reduce_mean(where_tensor)

When I try to run this code, it fails with:

File "d:\Development\transformer_chatbot\chatbot\transformer.py", line 211, in _count_mrr
      y_true = tf.reshape(y_true, shape=(self.max_length - 1, 1))
Node: 'Reshape_4'
Input to reshape is a tensor with 1248 values, but the requested shape has 39
     [[{{node Reshape_4}}]] [Op:__inference_train_function_39181]
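
This is consistent with the batching suspicion above: with a real batch size of 32 and max_length = 40 (values inferred from the error, so treat them as assumptions), y_true actually arrives with 32 * 39 = 1248 elements, and reshaping that into 39 rows cannot work. A minimal sketch that reproduces the same error:

import tensorflow as tf

y_true = tf.zeros((32, 39), dtype=tf.int32)  # what the batched pipeline really delivers
tf.reshape(y_true, shape=(39, 1))            # fails: 1248 values, but the requested shape has 39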

But!!! If I instead try y_true = tf.reshape(y_true, shape=(1248, 1)), I get this:

File "d:\Development\transformer_chatbot\chatbot\transformer.py", line 211, in _count_mrr  *
        y_true = tf.reshape(y_true, shape=(1248, 1))

    ValueError: Cannot reshape a tensor with 39 elements to shape [1248,1] (1248 elements) for 
    '{{node Reshape_4}} = Reshape[T=DT_INT32, Tshape=DT_INT32](Squeeze, Reshape_4/shape)' with
    input shapes: [39], [2] and with input tensors computed as partial shapes: input[1] = [1248,1].

The full model if needed

Answer from nnsrf1az:

As I mentioned above, the problem comes from batching with tf.data.Dataset.
The data is loaded with 3 dimensions, e.g. (None, maxlen - 1, vocab_size), where None is the hidden batch size.
At the start, TensorFlow runs a couple of empty warm-up passes with a batch size of 1 to check that everything works, but the real batch size is 32.
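
This None batch dimension is easy to see on the dataset itself (a sketch with made-up sizes; maxlen and vocab_size here are assumptions):

import tensorflow as tf

maxlen, vocab_size = 40, 1000    # illustrative values only
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((100, maxlen - 1, vocab_size)))
ds = ds.batch(32)
print(ds.element_spec)           # TensorSpec(shape=(None, 39, 1000), dtype=tf.float32, name=None)
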
I solved it with a simpler tf.reshape() call: I collapse all of my batches into one flat unit, like this:

start = tf.constant([
    [[1, 2, 3], [1, 2, 3]],
    [[1, 2, 3], [1, 2, 3]],
])                                      # shape (None (here 2), 2, 3): the leading dim is the batch
end = tf.reshape(start, shape=(-1, 3))  # collapse the batch dimension into the rows
# end is now:
# tf.Tensor(
# [[1 2 3]
#  [1 2 3]
#  [1 2 3]
#  [1 2 3]], shape=(None (here 4), 3) -> concretely (4, 3))

and then work with that flattened tensor.
The final fix:

def _count_mrr(self, y_true: tf.Tensor, y_pred: tf.Tensor):
    # collapse the (hidden) batch dimension with -1 instead of hard-coding a size
    y_true = tf.reshape(y_true, shape=(-1, 1))
    y_pred = tf.reshape(y_pred, shape=(-1, self._data_controller._vocab_size))
    # rank every vocabulary entry by predicted score
    y_pred = tf.math.top_k(y_pred, k=self._data_controller._vocab_size).indices
    y_true = tf.cast(y_true, dtype=tf.dtypes.int32)
    # position of the true token in the ranking -> 1-based rank
    where_tensor = tf.equal(y_true, y_pred)
    where_tensor = tf.where(where_tensor)[:, 1]
    where_tensor = tf.cast(tf.add(where_tensor, 1), dtype=tf.dtypes.float64)
    # reciprocal ranks, then their mean
    where_tensor = tf.divide(
               tf.constant(np.ones(where_tensor._shape_as_list()[0])),
               where_tensor)
    return tf.math.reduce_mean(where_tensor)
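
As a quick check of the key change (a sketch; 39 = max_length - 1 is taken from the error messages above), the shape=(-1, 1) reshape now works both for the size-1 warm-up batch and for the real batch of 32:

import tensorflow as tf

for batch_size in (1, 32):
    y = tf.zeros((batch_size, 39), dtype=tf.int32)
    print(tf.reshape(y, shape=(-1, 1)).shape)    # (39, 1), then (1248, 1)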

Note that this solution is not perfect: ranking the whole vocabulary with top_k materializes a large (batch * (max_length - 1), vocab_size) index tensor, so it allocates a lot of memory.
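
If that memory cost ever becomes a problem, one possible alternative (not part of the original answer, just a sketch; the helper name is mine) is to compute each rank directly as one plus the number of scores strictly greater than the true token's score, which skips the full top_k index tensor (ties may be ranked slightly differently):

import tensorflow as tf

def mrr_without_full_topk(y_true, y_pred):
    # y_true: (N, 1) int32 token ids, y_pred: (N, vocab_size) raw scores
    true_scores = tf.gather(y_pred, y_true[:, 0], batch_dims=1)                          # (N,)
    rank = 1 + tf.reduce_sum(tf.cast(y_pred > true_scores[:, None], tf.int32), axis=1)   # (N,)
    return tf.math.reduce_mean(1.0 / tf.cast(rank, tf.float64))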

  • I hope my experience helps someone else who runs into the same problem.
