🐛 Bug
To Reproduce
Steps to reproduce the behavior:
- I built a model that uses the conv2d op.
- One of the compute graphs is: permute_dims --> conv2d --> layernorm.
- During compilation I ran into the problem below.

I think the issue is caused by the permute and conv operators being fused, so that in dl.gpu.Matmul() the buffer shape and the index_map shape no longer match.
1. Error log:

```
tvm.error.InternalError: Traceback (most recent call last):
  4: operator()
        at /workspace/tvm-unity/src/tir/schedule/schedule.cc:287
  3: tvm::tir::TracedScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/traced_schedule.cc:678
  2: tvm::tir::ConcreteScheduleNode::TransformLayout(tvm::tir::BlockRV const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/concrete_schedule.cc:993
  1: tvm::tir::TransformLayout(tvm::tir::ScheduleState, tvm::tir::StmtSRef const&, int, tvm::tir::BufferIndexType, tvm::tir::IndexMap const&, tvm::runtime::Optional<tvm::tir::IndexMap> const&, bool)
        at /workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc:1160
  0: tvm::tir::LegalizeIndexMapDType(tvm::tir::IndexMap const&, tvm::runtime::Array<tvm::PrimExpr, void> const&)
        at /workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc:1106
  File "/workspace/tvm-unity/src/tir/schedule/primitive/layout_transformation.cc", line 1106
InternalError: Check failed: (args.size() == initial_indices_orig.size()) is false:
```
2. Additional information:

1) The index map that does not seem to match:

```python
T.index_map(lambda i0, i1, i2, i3, i4, i5: (T.int64(0), i1 * T.int64(64) + i2, i3))
```
2) The conv2d block:

```python
with T.block("conv2d_nchw", no_realize=True):
    v_nn = T.axis.spatial(T.int64(1))
    v_ff = T.axis.spatial(T.int64(256))
    v_yy = T.axis.spatial(T.int64(64))
    v_xx = T.axis.spatial(T.int64(64))
    v_rc = T.axis.reduce(T.int64(768))
    v_ry = T.axis.reduce(T.int64(1))
    v_rx = T.axis.reduce(T.int64(1))
    pad_temp = T.Buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16")
    B = T.Buffer((T.int64(256), T.int64(768), T.int64(1), T.int64(1)), "float16")
    T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], B[v_ff, v_rc, v_ry, v_rx])
    conv2d_nchw = T.Buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float16")
    T.writes(conv2d_nchw[v_nn, v_ff, v_yy, v_xx])
    with T.init():
        conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
    conv2d_nchw[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * B[v_ff, v_rc, v_ry, v_rx]
```
3) The full prim_func:

```python
@T.prim_func(private=True)
def main(permute_dims161: T.Buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16"),
         vision_tower_vision_tower_high_neck_0_weight1: T.Buffer((T.int64(256), T.int64(768), T.int64(1), T.int64(1)), "float16"),
         compute_intermediate: T.Buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float32")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    pad_temp = T.alloc_buffer((T.int64(1), T.int64(768), T.int64(64), T.int64(64)), "float16")
    conv2d_nchw_intermediate = T.alloc_buffer((T.int64(1), T.int64(256), T.int64(64), T.int64(64)), "float16")
    for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(768), T.int64(64), T.int64(64)):
        with T.block("pad_temp"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(permute_dims161[v_i0, v_i1, v_i2, v_i3])
            T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
            pad_temp[v_i0, v_i1, v_i2, v_i3] = permute_dims161[v_i0, v_i1, v_i2, v_i3]
    for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(1), T.int64(256), T.int64(64), T.int64(64), T.int64(768), T.int64(1), T.int64(1)):
        with T.block("conv2d_nchw"):
            v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
            T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], vision_tower_vision_tower_high_neck_0_weight1[v_ff, v_rc, v_ry, v_rx])
            T.writes(conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx])
            with T.init():
                conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
            conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * vision_tower_vision_tower_high_neck_0_weight1[v_ff, v_rc, v_ry, v_rx]
    for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(256), T.int64(64), T.int64(64)):
        with T.block("compute"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(conv2d_nchw_intermediate[v_i0, v_i1, v_i2, v_i3])
            T.writes(compute_intermediate[v_i0, v_i1, v_i2, v_i3])
            compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", conv2d_nchw_intermediate[v_i0, v_i1, v_i2, v_i3])
```
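The failing check compares the number of buffer indices being transformed with the number of initial indices of the index map. As a rough illustration (my own sketch, not compiler output): the map above has six initial indices, while every buffer in the dumped prim_func is 4-D, which is exactly the kind of arity mismatch the check rejects.

```python
import tvm
from tvm import tir

# The index map reported in the log above: six input indices, three outputs.
index_map = tir.IndexMap.from_func(
    lambda i0, i1, i2, i3, i4, i5: [tir.IntImm("int64", 0), i1 * 64 + i2, i3]
)

# The conv2d output buffer in the dumped prim_func is 4-D.
conv2d_shape = (1, 256, 64, 64)

# 6 initial indices vs. 4 buffer indices -- the arity mismatch behind
# "Check failed: (args.size() == initial_indices_orig.size())".
print(len(index_map.initial_indices), "map indices vs", len(conv2d_shape), "buffer dims")
```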
4 Answers

0kjbasz6 #1
Thanks for the report. A minimal reproduction would be helpful.
You can dump the TVMScript before the transformation, minimize it, and run the transformation you mentioned on it.
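A hedged sketch of that workflow (assuming `mod` is the Relax IRModule right before the failing dlight pass and that the target is CUDA):

```python
import tvm
from tvm import dlight as dl

# Assumption: `mod` is the IRModule right after LegalizeOps / fusion,
# i.e. just before the dlight scheduling pass that crashes.
print(mod.script())  # dump TVMScript; copy out and minimize the offending prim_func

# Re-run only the rule that fails, on the (minimized) module.
with tvm.target.Target("cuda"):
    mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
```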
r1zhe5dt #2
I ran into the same problem. Did you manage to solve it?
ijxebb2r #3
> Thanks for the report. A minimal reproduction would be helpful. You can dump the TVMScript before the transformation, minimize it, and run the transformation you mentioned on it.
Hi, this bug can be reproduced like this: the error is reported by `dl.gpu.Matmul` after `tvm.relax.transform.LegalizeOps()` only when kernel_size equals 1. In MLLM models, image embedding may be involved, and this op (a conv2d with kernel_size == 1) is likely to be used.
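A hedged, self-contained sketch of such a reproduction (the shapes come from the dumped prim_func above; the Relax module itself is my own minimal construction, not the reporter's model, and whether it triggers the error may depend on the TVM version):

```python
import tvm
from tvm import dlight as dl, relax
from tvm.script import ir as I, relax as R


@I.ir_module
class ConvModule:
    @R.function
    def main(
        x: R.Tensor((1, 768, 64, 64), "float16"),
        w: R.Tensor((256, 768, 1, 1), "float16"),
    ) -> R.Tensor((1, 256, 64, 64), "float16"):
        with R.dataflow():
            # 1x1 conv2d -- the case this answer says triggers the error.
            y = R.nn.conv2d(x, w, strides=[1, 1], padding=[0, 0, 0, 0],
                            data_layout="NCHW", kernel_layout="OIHW",
                            out_dtype="float16")
            R.output(y)
        return y


mod = relax.transform.LegalizeOps()(ConvModule)
with tvm.target.Target("cuda"):
    # On affected builds this is expected to hit the InternalError shown above.
    mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
```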
am46iovg #4

Hi @senlyu163. It looks like this is a known issue when applying dlight to a conv2d whose kernel size is 1. The problem arises because the reindex schedule simplifies the index expressions. To address it, I created a draft PR earlier; you can merge the related changes and modify dlight's normalize_to_matmul function. Check out the draft PR: apache/tvm#16440
The key part of the change is adding a skip_simplify flag to cache_reindex; the related changes can be picked up from that PR.
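The PR's diff is not reproduced here. As a small helper (my own sketch, not part of the PR), the following only checks whether a local TVM checkout already contains the skip_simplify change, assuming the GPU matmul rule lives in `tvm.dlight.gpu.matmul`:

```python
import inspect

# Assumption: the dlight GPU matmul rule is implemented in tvm.dlight.gpu.matmul.
from tvm.dlight.gpu import matmul as dlight_matmul

# Look for the skip_simplify flag mentioned in the draft PR (apache/tvm#16440).
source = inspect.getsource(dlight_matmul)
if "skip_simplify" in source:
    print("Local dlight matmul rule already references skip_simplify.")
else:
    print("skip_simplify not found -- the draft PR changes are not merged locally.")
```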