tensorflow 如何向einsum操作添加另一个维度?

cczfrluj  于 2023-08-06  发布在  其他
关注(0)|答案(1)|浏览(102)

假设tensortensor1是代码片段中提供的形状对输入的一些计算转换。einsum运算执行爱因斯坦求和,以特定顺序聚合结果。

import tensorflow as tf

tf.random.set_seed(0)

tensor = tf.random.uniform(shape=(2, 2, 2)) # Shape: (n_nodes, n_nodes, n_heads)
tensor1 = tf.random.uniform(shape=(2, 2, 2)) # Shape: (n_nodes, n_heads, n_units)

print(tensor)
print("-" * 50)
print(tensor1)
print("-" * 50)

einsum_tensor = tf.einsum('ijh, jhu -> ihu', tensor, tensor1) # Shape: (n_nodes, n_heads, n_units)

print(einsum_tensor)

字符串
如果添加批处理维度,如何修改einsum操作?如果存在批尺寸,意味着新形状将是:

tensor shape: (batch_size, n_nodes, n_nodes, n_heads)
tensor1 shape: (batch_size, n_nodes, n_heads, n_units)
output shape: (batch_size, n_nodes, n_heads, n_units)


我想到了下面的修改,但我不知道是不是真的。我从原始操作中了解到,jh是伪索引,iu是自由索引。

einsum_tensor = tf.einsum('bijh, bjhu -> bihu', tensor, tensor1)


这个guide是我正在使用的引用(第228行)。请注意,我已经将指南中的f更改为u
P.S:我在人工智能堆栈上问过这个问题,但他们建议这是一个编程问题,应该在这里问。

kr98yfug

kr98yfug1#

您可以像这样使用点运算符

tf.einsum('...ijh, ...jhu -> ...ihu', tensor, tensor1)

字符串
看看documentation

相关问题