Paddle：使用 Pass 替换 Op 之后，多出了 assign 算子

bq9c1y66  于 5个月前  发布在  其他
关注(0)|答案(1)|浏览(54)

请提出你的问题 Please ask your question

现象:

使用 Pass 替换 Op 之后，在函数 `ProgramInterpreter::RunOperator` 中添加如下打印，发现：

在替换后的 Op 后面，会多出非预期的 assign 算子。

该Pass目的是对GPT3 Attention部分进行算子融合

Pass的pattern和replace代码:

def gen_fuse_attention_cached_layer():
    """Build the ``(pattern, replace)`` pair for fusing one cached GPT-3 layer.

    ``pattern`` spells out the unfused subgraph of a transformer layer that
    reads a key/value cache (``past_key`` / ``past_value``). ``replace`` emits
    a single ``gpt3_layer`` op in its place.

    Both functions take the same 16 parameters in the same order. Both return
    three values:

    - the layer output;
    - the new key cache (``past_key`` concatenated with this step's ``k``);
    - the new value cache (``past_value`` concatenated with this step's ``v``).

    NOTE(review): the thread this snippet comes from reports unexpected
    ``assign`` ops after the fused op once this replacement is applied. The
    cause was not identified there (a maintainer suggested a custom pass as a
    workaround), so confirm the rewritten graph before relying on this pass.

    Returns:
        tuple: ``(pattern, replace)``.
    """

    def _upscale_dropout(x):
        """Append a ``dropout`` op in ``upscale_in_train`` mode on ``x``.

        Returns the op helper produced by the call, as the original inline
        code did. ``Mask`` is removed from the helper's outputs before the
        call. Presumably this keeps the pattern from declaring a ``Mask`` var;
        confirm against the PassDesc implementation.
        """
        dropout_op = paddle.incubate.passes.ir.PassDesc.OP.dropout
        dropout_op.SetAttr("dropout_implementation", "upscale_in_train")
        dropout_op._outputs.pop("Mask")
        return dropout_op(X=x)

    def pattern(embeddings, norm_weight, norm_bias, mix_linear_weight, mix_linear_bias,
                self_out_linear_weight, self_out_linear_bias, self_out_norm_weight, self_out_norm_bias,
                ffn_linear_weight, ffn_linear_bias, ffn_out_linear_weight, ffn_out_linear_bias,
                attn_mask, past_key, past_value):
        """Unfused subgraph to match.

        Returns (layer output, concatenated k, concatenated v).
        """
        op = paddle.incubate.passes.ir.PassDesc.OP

        # Attention input: layernorm -> fused QKV linear -> reshape -> split.
        x = layernorm_only_out(embeddings, norm_weight, norm_bias)
        x = linear_with_params(x, mix_linear_weight, mix_linear_bias)
        x = reshape_without_shape_not_set_attr(x)
        q, k, v = split_3(x)

        transed_q = transpose_without_shape(q)

        # Append this step's k/v to the cache. The concatenated tensors are
        # also returned as the pattern's 2nd and 3rd outputs.
        concated_k = concat_in_axis_1(past_key, k)
        transed_k = transpose_without_shape(concated_k)
        concated_v = concat_in_axis_1(past_value, v)
        transed_v = transpose_without_shape(concated_v)

        # Attention scores: scale(q) @ k^T, scale, + mask, softmax, dropout.
        scaled_q = op.scale(X=transed_q)
        q_mul_k = op.matmul_v2(X=scaled_q, Y=transed_k)
        q_mul_k.SetAttr("trans_y", True)
        scaled_q_mul_k = op.scale(X=q_mul_k)
        added_attn_weight = op.elementwise_add(X=scaled_q_mul_k, Y=attn_mask)
        softmax_attn_weight = op.softmax(X=added_attn_weight)
        attn_probs = _upscale_dropout(softmax_attn_weight)

        # Attention output: probs @ v -> transpose -> reshape -> linear -> dropout.
        out = op.matmul_v2(X=attn_probs, Y=transed_v)
        out = transpose_without_shape(out)
        out = reshape_without_shape(out)
        out = linear_with_params(out, self_out_linear_weight, self_out_linear_bias)
        out = _upscale_dropout(out)

        # First residual connection, then layernorm on its result.
        res_add_out = op.elementwise_add(X=embeddings, Y=out)
        layer_out = layernorm_only_out(res_add_out, self_out_norm_weight, self_out_norm_bias)

        # Feed-forward: linear -> gelu -> linear -> dropout.
        linear_out = linear_with_params(layer_out, ffn_linear_weight, ffn_linear_bias)
        gelu_out = op.gelu(X=linear_out)
        linear_2_out = linear_with_params(gelu_out, ffn_out_linear_weight, ffn_out_linear_bias)
        drop_3_out = _upscale_dropout(linear_2_out)

        # Second residual connection (adds to the pre-layernorm sum).
        res_add_out_2 = op.elementwise_add(X=res_add_out, Y=drop_3_out)

        return res_add_out_2, concated_k, concated_v

    def gpt3_layer_cache_adaptor(embeddings, norm_weight, norm_bias, mix_linear_weight, mix_linear_bias,
                                 self_out_linear_weight, self_out_linear_bias, self_out_norm_weight, self_out_norm_bias,
                                 ffn_linear_weight, ffn_linear_bias, ffn_out_linear_weight, ffn_out_linear_bias,
                                 attn_mask, past_key, past_value):
        """Append one ``gpt3_layer`` op fed by all pattern inputs.

        The op's ``Out``, ``PresentKey`` and ``PresentValue`` outputs are bound
        by hand to freshly generated variable names.

        Returns:
            tuple: the three created variables, in that slot order.
        """
        gpt3_layer_cache_op = paddle.incubate.passes.ir.PassDesc.OP.gpt3_layer
        # Drop the helper's declared outputs (presumably so the call below
        # creates none); they are named and bound explicitly afterwards.
        gpt3_layer_cache_op._outputs = {}
        gpt3_layer_cache_op(
            Hidden=embeddings,
            NormWeight=norm_weight,
            NormBias=norm_bias,
            MixLinearWeight=mix_linear_weight,
            MixLinearBias=mix_linear_bias,
            SelfOutLinearWeight=self_out_linear_weight,
            SelfOutLinearBias=self_out_linear_bias,
            SelfOutNormWeight=self_out_norm_weight,
            SelfOutNormBias=self_out_norm_bias,
            FfnLinearWeight=ffn_linear_weight,
            FfnLinearBias=ffn_linear_bias,
            FfnOutLinearWeight=ffn_out_linear_weight,
            FfnOutLinearBias=ffn_out_linear_bias,
            AttentionMask=attn_mask,
            PastKey=past_key,
            PastValue=past_value)

        output_slots = ("Out", "PresentKey", "PresentValue")
        outs_name = [paddle.fluid.unique_name.generate('gpt3_layer') for _ in output_slots]
        for slot, name in zip(output_slots, outs_name):
            gpt3_layer_cache_op._desc.set_output(slot, [name])

        block = paddle.static.default_main_program().current_block()
        return tuple(block.create_var(name=name) for name in outs_name)

    def replace(embeddings, norm_weight, norm_bias, mix_linear_weight, mix_linear_bias,
                self_out_linear_weight, self_out_linear_bias, self_out_norm_weight, self_out_norm_bias,
                ffn_linear_weight, ffn_linear_bias, ffn_out_linear_weight, ffn_out_linear_bias,
                attn_mask, past_key, past_value):
        """Replacement subgraph: a single fused ``gpt3_layer`` op."""
        return gpt3_layer_cache_adaptor(
            embeddings, norm_weight, norm_bias, mix_linear_weight, mix_linear_bias,
            self_out_linear_weight, self_out_linear_bias, self_out_norm_weight, self_out_norm_bias,
            ffn_linear_weight, ffn_linear_bias, ffn_out_linear_weight, ffn_out_linear_bias,
            attn_mask, past_key, past_value)

    return pattern, replace

由于 while op 内部的替换效果无法显示，这里用 while op 之外的完整 layer 的替换效果代替：

kq0g1dla

kq0g1dla1#

建议先通过 custom pass 的方式绕过该问题，后续我们需要定位一下是否是框架的问题并进行修复。@ronny1996

相关问题