为什么Rust编译器不能自动向量化这个FP点产品实现?

eoxn13cs  于 2023-04-30  发布在  其他
关注(0)|答案(1)|浏览(122)

让我们考虑一个简单的约简,例如点积:

pub fn add(a:&[f32], b:&[f32]) -> f32 {
    a.iter().zip(b.iter()).fold(0.0, |c,(x,y)| c+x*y))
}

使用rustc 1.68与-C opt-level=3 -C target-feature=+avx2,+fma我得到

.LBB0_5:
        vmovss  xmm1, dword ptr [rdi + 4*rsi]
        vmulss  xmm1, xmm1, dword ptr [rdx + 4*rsi]
        vmovss  xmm2, dword ptr [rdi + 4*rsi + 4]
        vaddss  xmm0, xmm0, xmm1
        vmulss  xmm1, xmm2, dword ptr [rdx + 4*rsi + 4]
        vaddss  xmm0, xmm0, xmm1
        vmovss  xmm1, dword ptr [rdi + 4*rsi + 8]
        vmulss  xmm1, xmm1, dword ptr [rdx + 4*rsi + 8]
        vaddss  xmm0, xmm0, xmm1
        vmovss  xmm1, dword ptr [rdi + 4*rsi + 12]
        vmulss  xmm1, xmm1, dword ptr [rdx + 4*rsi + 12]
        lea     rax, [rsi + 4]
        vaddss  xmm0, xmm0, xmm1
        mov     rsi, rax
        cmp     rcx, rax
        jne     .LBB0_5

这是具有循环展开的标量实现,甚至不将穆尔+add收缩到FMA中。从这个代码到simd代码应该很容易,为什么rustc不优化这个?
如果我用i32替换f32,我会得到所需的自动矢量化:

.LBB0_5:
        vmovdqu ymm4, ymmword ptr [rdx + 4*rax]
        vmovdqu ymm5, ymmword ptr [rdx + 4*rax + 32]
        vmovdqu ymm6, ymmword ptr [rdx + 4*rax + 64]
        vmovdqu ymm7, ymmword ptr [rdx + 4*rax + 96]
        vpmulld ymm4, ymm4, ymmword ptr [rdi + 4*rax]
        vpaddd  ymm0, ymm4, ymm0
        vpmulld ymm4, ymm5, ymmword ptr [rdi + 4*rax + 32]
        vpaddd  ymm1, ymm4, ymm1
        vpmulld ymm4, ymm6, ymmword ptr [rdi + 4*rax + 64]
        vpmulld ymm5, ymm7, ymmword ptr [rdi + 4*rax + 96]
        vpaddd  ymm2, ymm4, ymm2
        vpaddd  ymm3, ymm5, ymm3
        add     rax, 32
        cmp     r8, rax
        jne     .LBB0_5
u5rb5r59

u5rb5r591#

这是因为浮点数不是关联的,通常意味着a+(b+c) != (a+b)+c。因此,对浮点数求和成为串行任务,因为编译器不会将((a+b)+c)+d重新排序为(a+b)+(c+d)。最后一个可以矢量化,第一个不能。
在大多数情况下,程序员并不关心求和顺序的差异。
GCC和clang提供-fassociative-math标志,其将允许编译器重新排序浮点运算以获得性能。
rustc不提供这一点,据我所知,llvm也不接受会改变这种行为的标志。
在夜间Rust中,你可以使用#![feature(core_intrinsics)]来获得优化:

#![feature(core_intrinsics)]
pub fn add(a:&[f32], b:&[f32]) -> f32 {
    unsafe {
        a.iter().zip(b.iter()).fold(0.0, |c,(x,y)| std::intrinsics::fadd_fast(c,x*y))
    }
}

这不使用fma。所以对于fma你必须用途:

#![feature(core_intrinsics)]
pub fn add(a:&[f32], b:&[f32]) -> f32 {
    unsafe {
        a.iter().zip(b.iter()).fold(0.0, |c,(&x,&y)| std::intrinsics::fadd_fast(c,std::intrinsics::fmul_fast(x,y)))
    }
}

我不知道有一个稳定的Rust解决方案,它不涉及显式的simd intrinsic。

相关问题