python tensorflow Map函数到多Tensor

s4n0splo  于 2023-03-28  发布在  Python
关注(0)|答案(1)|浏览(94)

我在TensorFlow的自定义层中使用以下函数来重新排列查询,键值:

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

它会发出警告
WARNING:tensorflow:From /usr/local/lib/python3.9/dist-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
有没有更TensorFlowic的方法来做到这一点?
我尝试使用map_fn如下,它抛出相同的警告和错误:

import tensorflow as tf
from einops import rearrange
a = tf.random.uniform((1, 196, 196))
b, c, d = tf.map_fn(lambda t: rearrange(t, 'b (h n) d -> b h n d', h=14), [a, a, a])

从文档中看,似乎是tf.map_fn,但它似乎可以在Tensor堆栈上工作。将Tensor堆栈会更好吗?

bnl4lu3b

bnl4lu3b1#

几个错误-首先,您需要将einops tensorflow layer用于rearrange。下面的更改应该可以工作:

from einops.layers.tensorflow import Rearrange

a = tf.random.uniform(shape=(1, 196, 196))
q, k, v = tf.map_fn(fn=lambda t: Rearrange('b (h n) d -> b h n d', h=14)(t), elems=tf.stack([a, a, a]))
print(q.shape, k.shape, v.shape)
# (1, 14, 14, 196) (1, 14, 14, 196) (1, 14, 14, 196)

相关问题