tensorflow 二维矩阵的tf.diag

pw9qyyiw  于 2022-12-19  发布在  其他
关注(0)|答案(2)|浏览(199)

我有一个2DTensor,它具有各种阵列,定义如下:

x = tf.constant([[0,1,2],[-1,0,1],[-1,-2,0]])

我想把每个数组转换成一个对角矩阵

diag_x =
[[[ 0,  0,  0],
  [ 0,  1,  0],
  [ 0,  0,  2]],
 [[-1,  0,  0],
  [ 0,  0,  0],
  [ 0,  0,  1]],
 [[-2,  0,  0],
  [ 0, -1,  0],
  [ 0,  0,  0]]]

但是如果我使用操作 * tf.diag(x)*,输出不是这个。

xxslljrj

xxslljrj1#

我终于找到了解决办法:

tf.matrix_diag(x)
tkclm6bt

tkclm6bt2#

编辑:对于TF 2.0,您可以使用

tf.linalg.diag(x)

您可以尝试:

tf.matrix_set_diag(tf.zeros((3,3,3), dtype=tf.int32), x)

相关问题