Please ask your question
Symptom:
After replacing the Op with the Pass, I added a print in ProgramInterpreter::RunOperator and found that unexpected assign ops show up after the replaced Op.
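The output of that print is not attached here. Assuming the extra assign ops are inserted into the program by the pass itself (rather than by the executor at run time), a minimal Python-side sketch for spotting them is to dump the op types of the block after the pass has run; main_program below is an assumed handle to that program:

import paddle

def dump_op_types(main_program):
    # Print every op in the global block; after the pass has run, the extra
    # assign ops should appear right after the fused gpt3_layer op.
    for i, op in enumerate(main_program.global_block().ops):
        print(i, op.type)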
This Pass is meant to fuse the operators of the GPT-3 attention block.
Pattern and replace code of the Pass:
import paddle

# Helper functions such as layernorm_only_out, linear_with_params,
# reshape_without_shape_not_set_attr, split_3, transpose_without_shape,
# concat_in_axis_1 and reshape_without_shape are defined elsewhere in the
# model code and are omitted here.

def gen_fuse_attention_cached_layer():
    def pattern(embeddings, norm_weight, norm_bias, mix_linear_weight, mix_linear_bias,
                self_out_linear_weight, self_out_linear_bias, self_out_norm_weight, self_out_norm_bias,
                ffn_linear_weight, ffn_linear_bias, ffn_out_linear_weight, ffn_out_linear_bias,
                attn_mask, past_key, past_value):
        # layernorm
        x = layernorm_only_out(embeddings, norm_weight, norm_bias)
        # linear
        x = linear_with_params(x, mix_linear_weight, mix_linear_bias)
        # reshape
        x = reshape_without_shape_not_set_attr(x)
        # split into q, k, v
        q, k, v = split_3(x)
        # q transpose
        transed_q = transpose_without_shape(q)
        # concatenate the cached key/value with the current key/value, then transpose
        concated_k = concat_in_axis_1(past_key, k)
        transed_k = transpose_without_shape(concated_k)
        concated_v = concat_in_axis_1(past_value, v)
        transed_v = transpose_without_shape(concated_v)
        # scaled dot-product attention
        scaled_q = paddle.incubate.passes.ir.PassDesc.OP.scale(X=transed_q)
        q_mul_k = paddle.incubate.passes.ir.PassDesc.OP.matmul_v2(X=scaled_q, Y=transed_k)
        q_mul_k.SetAttr("trans_y", True)
        scaled_q_mul_k = paddle.incubate.passes.ir.PassDesc.OP.scale(X=q_mul_k)
        added_attn_weight = paddle.incubate.passes.ir.PassDesc.OP.elementwise_add(X=scaled_q_mul_k, Y=attn_mask)
        softmax_attn_weight = paddle.incubate.passes.ir.PassDesc.OP.softmax(X=added_attn_weight)
        dropout_op = paddle.incubate.passes.ir.PassDesc.OP.dropout
        dropout_op.SetAttr("dropout_implementation", "upscale_in_train")
        dropout_op._outputs.pop("Mask")
        softmax_attn_weight = dropout_op(X=softmax_attn_weight)
        out = paddle.incubate.passes.ir.PassDesc.OP.matmul_v2(X=softmax_attn_weight, Y=transed_v)
        out = transpose_without_shape(out)
        out = reshape_without_shape(out)
        # output projection linear
        out = linear_with_params(out, self_out_linear_weight, self_out_linear_bias)
        dropout_op_2 = paddle.incubate.passes.ir.PassDesc.OP.dropout
        dropout_op_2.SetAttr("dropout_implementation", "upscale_in_train")
        dropout_op_2._outputs.pop("Mask")
        out = dropout_op_2(X=out)
        # residual add
        res_add_out = paddle.incubate.passes.ir.PassDesc.OP.elementwise_add(X=embeddings, Y=out)
        layer_out = layernorm_only_out(res_add_out, self_out_norm_weight, self_out_norm_bias)
        # FFN
        linear_out = linear_with_params(layer_out, ffn_linear_weight, ffn_linear_bias)
        gelu_out = paddle.incubate.passes.ir.PassDesc.OP.gelu(X=linear_out)
        linear_2_out = linear_with_params(gelu_out, ffn_out_linear_weight, ffn_out_linear_bias)
        dropout_op_3 = paddle.incubate.passes.ir.PassDesc.OP.dropout
        dropout_op_3.SetAttr("dropout_implementation", "upscale_in_train")
        dropout_op_3._outputs.pop("Mask")
        drop_3_out = dropout_op_3(X=linear_2_out)
        # residual add
        res_add_out_2 = paddle.incubate.passes.ir.PassDesc.OP.elementwise_add(X=res_add_out, Y=drop_3_out)
        return res_add_out_2, concated_k, concated_v

    def gpt3_layer_cache_adaptor(embeddings, norm_weight, norm_bias, mix_linear_weight, mix_linear_bias,
                                 self_out_linear_weight, self_out_linear_bias, self_out_norm_weight, self_out_norm_bias,
                                 ffn_linear_weight, ffn_linear_bias, ffn_out_linear_weight, ffn_out_linear_bias,
                                 attn_mask, past_key, past_value):
        gpt3_layer_cache_op = paddle.incubate.passes.ir.PassDesc.OP.gpt3_layer
        gpt3_layer_cache_op._outputs = {}
        gpt3_layer_cache_op(
            Hidden=embeddings,
            NormWeight=norm_weight,
            NormBias=norm_bias,
            MixLinearWeight=mix_linear_weight,
            MixLinearBias=mix_linear_bias,
            SelfOutLinearWeight=self_out_linear_weight,
            SelfOutLinearBias=self_out_linear_bias,
            SelfOutNormWeight=self_out_norm_weight,
            SelfOutNormBias=self_out_norm_bias,
            FfnLinearWeight=ffn_linear_weight,
            FfnLinearBias=ffn_linear_bias,
            FfnOutLinearWeight=ffn_out_linear_weight,
            FfnOutLinearBias=ffn_out_linear_bias,
            AttentionMask=attn_mask,
            PastKey=past_key,
            PastValue=past_value)
        # manually bind the three outputs of the fused op to fresh variable names
        outs_name = [paddle.fluid.unique_name.generate('gpt3_layer') for i in range(3)]  # 3 outputs
        print(outs_name)
        gpt3_layer_cache_op._desc.set_output("Out", [outs_name[0]])
        gpt3_layer_cache_op._desc.set_output("PresentKey", [outs_name[1]])
        gpt3_layer_cache_op._desc.set_output("PresentValue", [outs_name[2]])
        block = paddle.static.default_main_program().current_block()
        results = []
        for out in outs_name:
            results.append(block.create_var(name=out))
        return results[0], results[1], results[2]

    def replace(embeddings, norm_weight, norm_bias, mix_linear_weight, mix_linear_bias,
                self_out_linear_weight, self_out_linear_bias, self_out_norm_weight, self_out_norm_bias,
                ffn_linear_weight, ffn_linear_bias, ffn_out_linear_weight, ffn_out_linear_bias,
                attn_mask, past_key, past_value):
        out = gpt3_layer_cache_adaptor(embeddings, norm_weight, norm_bias, mix_linear_weight, mix_linear_bias,
                                       self_out_linear_weight, self_out_linear_bias, self_out_norm_weight, self_out_norm_bias,
                                       ffn_linear_weight, ffn_linear_bias, ffn_out_linear_weight, ffn_out_linear_bias,
                                       attn_mask, past_key, past_value)
        return out[0], out[1], out[2]

    return pattern, replace
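Not shown in the snippet above: the returned (pattern, replace) pair still has to be registered as a generate pass before it takes effect. Below is a minimal sketch of how that is usually done with paddle.incubate.passes.ir.RegisterPass; the pass name gpt3_fuse_attention_cached_layer and the way the pass is applied to a graph are assumptions, not taken from the original post.

import paddle
from paddle.incubate.passes import ir
from paddle.fluid import core

# The registered pass name follows the decorated function's name
# ("gpt3_fuse_attention_cached_layer" is an assumed name).
@ir.RegisterPass
def gpt3_fuse_attention_cached_layer():
    return gen_fuse_attention_cached_layer()

# One (assumed) way to run the registered pass over a static program:
# build an IR graph from the program desc and apply the pass to it.
graph = core.Graph(paddle.static.default_main_program().desc)
core.get_pass("gpt3_fuse_attention_cached_layer").apply(graph)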
Because the replacement result inside the while op cannot be shown directly, the replacement of a full layer outside the while op is shown instead:
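To check whether the pattern was also replaced inside the while op, one option (an assumption about the setup, not from the original post) is to walk every block of the program, since the while op's body lives in a sub-block rather than in the global block:

# main_program is again an assumed handle to the program after the pass.
for block in main_program.blocks:
    print("block", block.idx)  # block 0 is the global block; the while body is a sub-block
    for op in block.ops:
        print("  ", op.type)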
1 answer
We suggest working around this with a custom pass for now; we will need to investigate later whether it is a framework issue and fix it. @ronny1996