请提出你的问题 Please ask your question
目前的conv2d 的kernel选择逻辑疑问:
看了下代码这里kernel_key中传入了俩输入:input和filter
当input为dtype=fp32 filter为dtype=fp16的时候:
选择的kernel是fp16的,是否符合预期?
import paddle
import paddle.nn.functional as F
import numpy as np
x_var = paddle.randn((2, 3, 8, 8), dtype='float32')
w_var = paddle.to_tensor(np.random.rand(6,3,3,3).astype('float16'))
y_var = F.conv2d(x_var, w_var)
3条答案
按热度按时间vktxenjb1#
@iclementine 请问能帮看下吗?
bcs8qyzn2#
当input为dtype=fp32 filter为dtype=fp16的时候: 选择的kernel是fp16的,是否符合预期?
正常,没有注册这种input为fp32,和filter=fp16这种实现,建议用同类型的输入dtype
zfciruhq3#
@LDOUBLEV 看上去目前的选择逻辑是根据最后一个参数的dtype选择的