Paddle 关于 Resnet 50 应用 pass优化 ,只有fuse_elewise_add_act_pass 起作用 ,其他的 pass 无效,无法得到全部的fuse op

yx2lnoni  于 2023-02-04  发布在  其他
关注(0)|答案(2)|浏览(226)

为使您的问题得到快速解决,在建立Issue前,请您先通过如下方式搜索是否有相似问题:【搜索issue关键字】【使用labels筛选】【官方文档】

paddle 2.2 版本
CPU 和GPU(cuda 11.2)两台机器

实验的目的的是想得到resnet 50 经过优化后的融合算子,为经优化的网络算子,打印如下
op type is conv2d
op type is batch_norm
op type is relu
op type is pool2d
op type is conv2d
op type is batch_norm
op type is relu
op type is conv2d
op type is batch_norm
op type is relu
op type is conv2d
op type is batch_norm
op type is conv2d
op type is batch_norm
op type is elementwise_add
op type is relu
op type is conv2d
。。。
。。。

采用的优化代码,应用pass 优化,代码来自 paddle 源码
Paddle-release-2.2\python\paddle\fluid\tests\unittests\test_apply_pass_to_program.py

代码中有两个测试case 只选择了第一个case进行修改 实验

import paddle
from paddle.vision.models import resnet50
from paddle.nn import CrossEntropyLoss
from paddle.fluid.framework import _apply_pass
from paddle.fluid.ir import apply_build_strategy
import paddle.fluid as fluid
import unittest
import numpy as np

def get_resnet50_model():
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
image = paddle.static.data(
name="image", shape=[None, 3, 224, 224], dtype="float32")
label = paddle.static.data(name="label", shape=[None, 1], dtype="int64")
model = resnet50()
loss_fn = CrossEntropyLoss()
pred = model(image)
loss = loss_fn(pred, label)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
optimizer.minimize(loss)

return main, startup, image, label, loss

def global_block_contains_op(program, op_type):
for op in program.global_block().ops:
if op.type == op_type:
return True
return False

class TestApplyPassToProgram(unittest.TestCase):
def setUp(self):
paddle.enable_static()

def test_case(self):
    main, startup, image, label, loss = get_resnet50_model()
    fused_op = "fused_elemwise_add_activation"
    self.assertFalse(global_block_contains_op(main, fused_op))
    attrs = {
        "int_attr": -3,
        "size_t_attr": 10,
        "float_attr": 3.25,
        "float32_attr": -4.5,
        "str_attr": "any string attr value",
    }
    attr_types = {
        "size_t_attr": "size_t",
        "float32_attr": "float32",
    }
    ret_attrs = _apply_pass(main, startup, "fuse_elewise_add_act_pass",
                            attrs, attr_types)          #这里是选择采用的pass  个人理解
    self.assertEqual(attrs, ret_attrs)
    self.assertTrue(global_block_contains_op(main, fused_op))

if name == "main":
unittest.main()

#上述代码 只测试了 fuse_elewise_add_act_ 这一种pss 策略 打印其中 op
op type is conv2d
op type is batch_norm
op type is share_buffer
op type is relu
op type is pool2d
op type is conv2d
op type is batch_norm
op type is share_buffer
op type is relu
op type is conv2d
op type is batch_norm
op type is share_buffer
op type is relu
op type is conv2d
op type is batch_norm
op type is conv2d
op type is batch_norm
op type is fused_elemwise_add_activation #这里确实出现了应该的融合算子
。。。
。。。

但是当 替换掉 成其他的pass ,好像并不会起作用,不会产生新的fuse op
实验代码
class TestApplyPassToProgram(unittest.TestCase):
def setUp(self):
paddle.enable_static()

def test_case(self):
    main, startup, image, label, loss = get_resnet50_model()
    # fused_op = "fuse_bn_add_act"
    # self.assertFalse(global_block_contains_op(main, fused_op))
    attrs = {
        "int_attr": -3,
        "size_t_attr": 10,
        "float_attr": 3.25,
        "float32_attr": -4.5,
        "str_attr": "any string attr value",
    }
    attr_types = {
        "size_t_attr": "size_t",
        "float32_attr": "float32",
    }
    ret_attrs = _apply_pass(main, startup, "fuse_relu_depthwise_conv_pass",
                            attrs, attr_types)      #在这里  换成其他的 pass 优化策略    ,但是没起作用
    self.assertEqual(attrs, ret_attrs)

根据我个人的理解 应该可能 会有 conv bn relu 算子融合在一起,不应该只有 elemwwise_add_act 这样一个算子

我想知道我该如何修改代码 可以实现所有可以应用在 resnet 50 的 pass 优化
或者百度 可以提供一份 resnet 50 优化好的 fuse op 嘛
非常感谢

ct3nt3jp

ct3nt3jp1#

您好,我们已经收到了您的问题,会安排技术人员尽快解答您的问题,请耐心等待。请您再次检查是否提供了清晰的问题描述、复现代码、环境&版本、报错信息等。同时,您也可以通过查看 官网API文档常见问题历史IssueAI社区 来寻求解答。祝您生活愉快~

Hi! We've received your issue and please be patient to get responded. We will arrange technicians to answer your questions as soon as possible. Please make sure that you have posted enough message to demo your request. You may also check out the APIFAQGithub Issue and AI community to get the answer.Have a nice day!

4uqofj5v

4uqofj5v2#

你好,可以参考这里,获得最好的训练速度性能哈: https://github.com/PaddlePaddle/Perf/blob/master/ResNet50V1.5/README.md

相关问题