tensorflow 3D稀疏Tensor矩阵乘法与2DTensor:无效参数错误:Tensor'a_shape'必须有2个元素 [操作:SparseTensorDenseMatMul]

z18hc3ub  于 4个月前  发布在  其他
关注(0)|答案(3)|浏览(84)

系统信息

  • TensorFlow版本2.4:

尝试使用2DTensor进行3D稀疏Tensor矩阵乘法。这是一个玩具示例:

import tensorflow as tf
import numpy as np

a = np.array([[[1., 0., 2., 0.],
              [3., 0., 0., 4.]]])
b = (np.array([1., 2.])[:,np.newaxis]).T

a_t = tf.constant(a)
b_t = tf.constant(b)

a_s = tf.sparse.from_dense(a_t)

tf.sparse.sparse_dense_matmul(b_t,a_s)

预期结果(1, 1, 4):
[[[7., 0., 2., 8.]]]
但实际上输出了一些错误:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-30-2ee379f1d3e8> in <module>
----> 1 tf.sparse.sparse_dense_matmul(b_t,a_s)

/Users/Mine/Python/tf2_4_env/lib/python3.6/site-packages/tensorflow/python/ops/sparse_ops.py in sparse_tensor_dense_matmul(sp_a, b, adjoint_a, adjoint_b, name)
   2564       return array_ops.transpose(
   2565           sparse_tensor_dense_matmul(
-> 2566               b, sp_a, adjoint_a=not adjoint_a, adjoint_b=not adjoint_b))
   2567 
   2568   else:

/Users/Mine/Python/tf2_4_env/lib/python3.6/site-packages/tensorflow/python/ops/sparse_ops.py in sparse_tensor_dense_matmul(sp_a, b, adjoint_a, adjoint_b, name)
   2577           b=b,
   2578           adjoint_a=adjoint_a,
-> 2579           adjoint_b=adjoint_b)
   2580 
   2581 

/Users/Mine/Python/tf2_4_env/lib/python3.6/site-packages/tensorflow/python/ops/gen_sparse_ops.py in sparse_tensor_dense_mat_mul(a_indices, a_values, a_shape, b, adjoint_a, adjoint_b, name)
   3049       return _result
   3050     except _core._NotOkStatusException as e:
-> 3051       _ops.raise_from_not_ok_status(e, name)
   3052     except _core._FallbackException:
   3053       pass

/Users/Mine/Python/tf2_4_env/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6860   message = e.message + (" name: " + name if name is not None else "")
   6861   # pylint: disable=protected-access
-> 6862   six.raise_from(core._status_to_exception(e.code, message), None)
   6863   # pylint: enable=protected-access
   6864 

/Users/Mine/Python/tf2_4_env/lib/python3.6/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: Tensor 'a_shape' must have 2 elements [Op:SparseTensorDenseMatMul]

请问是否可以添加3D稀疏Tensor矩阵乘法与2DTensor的可能性?类似于3DTensor与2DTensor的乘法

mitkmikd

mitkmikd1#

我也有同样的问题。你找到答案了吗?

9wbgstp7

9wbgstp73#

我想知道这个问题是否已经解决。
可以使用tf.map_fn或tf.vectorized_map在批轴上迭代tf.sparse.sparse_dense_matmul()与2DTensor。但不确定这是否是一个计算效率高的解决方案。

相关问题