After the first word is generated, `present` is not updated with the past keys and values.
    with tf.variable_scope(scope):
        c = conv1d(x, 'c_attn', n_state*3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        present = tf.stack([k, v], axis=1)  # <-- the line in question
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present
I think it should be changed to:
    with tf.variable_scope(scope):
        c = conv1d(x, 'c_attn', n_state*3)
        q, k, v = map(split_heads, tf.split(c, 3, axis=2))
        if past is not None:
            pk, pv = tf.unstack(past, axis=1)
            k = tf.concat([pk, k], axis=-2)
            v = tf.concat([pv, v], axis=-2)
        present = tf.stack([k, v], axis=1)  # <-- moved after the concat
        a = multihead_attn(q, k, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state)
        return a, present
1 Answer
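:-) They do this in the while loop in sample.py. Their code is written for efficiency: `attn` returns only the keys and values for the newly predicted word as `present`, rather than re-stacking the whole cache each time the logits for a new word are computed.
sample.py has a line of the form `... if past is None else tf.concat(...)`, which is where the new `present` gets appended to the existing `past`.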
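To make the division of labour concrete, here is a minimal NumPy sketch of the pattern described in this answer. It is not the code from the repository: `attn_step` and `generate_cache` are hypothetical stand-ins, the "projections" are fake, and the tensor layout `[batch, 2, heads, sequence, features]` is an assumption chosen to match the `axis=1` stack and `axis=-2` concat in the question.

```python
import numpy as np

def attn_step(x_new, past):
    """Toy stand-in for attn(): x_new is [batch, heads, new_len, features]."""
    k = x_new          # fake key projection, new tokens only
    v = x_new * 2.0    # fake value projection, new tokens only
    # present holds ONLY the new tokens' keys/values.
    present = np.stack([k, v], axis=1)
    if past is not None:
        pk, pv = past[:, 0], past[:, 1]
        # Attention itself still sees the full history.
        k = np.concatenate([pk, k], axis=-2)
        v = np.concatenate([pv, v], axis=-2)
    return k, v, present

def generate_cache(steps):
    """Toy stand-in for the sampling loop: the caller grows the cache."""
    past, k, v = None, None, None
    for x_new in steps:
        k, v, present = attn_step(x_new, past)
        # Caller-side accumulation, in the spirit of
        # `... if past is None else tf.concat(...)`.
        past = present if past is None else np.concatenate([past, present], axis=-2)
    return past, k, v

# A 3-token context followed by two single-token generation steps.
steps = [np.ones((1, 2, 3, 4)), np.ones((1, 2, 1, 4)) * 2, np.ones((1, 2, 1, 4)) * 3]
past, k, v = generate_cache(steps)
print(past.shape)
```

In this sketch the accumulated `past` ends up holding exactly the keys and values that the last attention call used. If `present` were instead built after the concat, as proposed in the question, and the caller kept appending it this way, the earlier tokens would be stored in the cache twice.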