rust 平方和的编译器优化[重复]

dm7nw8vv  于 2022-12-29  发布在  其他
关注(0)|答案(3)|浏览(218)
    • 此问题在此处已有答案**:

How does clang generate non-looping code for sum of squares?(2个答案)
昨天关门了。
这里有一些我觉得有趣的事情:

pub fn sum_of_squares(n: i32) -> i32 {
    let mut sum = 0;
    for i in 1..n+1 {
        sum += i*i;
    }
    sum
}

这是Rust中平方和的简单实现。这是rustc 1.65.0-O3的汇编代码

lea     ecx, [rdi + 1]
        xor     eax, eax
        cmp     ecx, 2
        jl      .LBB0_2
        lea     eax, [rdi - 1]
        lea     ecx, [rdi - 2]
        imul    rcx, rax
        lea     eax, [rdi - 3]
        imul    rax, rcx
        shr     rax
        imul    eax, eax, 1431655766
        shr     rcx
        lea     ecx, [rcx + 4*rcx]
        add     ecx, eax
        lea     eax, [rcx + 4*rdi]
        add     eax, -3
.LBB0_2:
        ret

我原以为它会使用平方和的公式,但它没有,它使用了一个神奇的数字1431655766,我一点也不懂。
然后我想看看clang和gcc在C++中对同一个函数做了什么

test    edi, edi
        jle     .L8
        lea     eax, [rdi-1]
        cmp     eax, 17
        jbe     .L9
        mov     edx, edi
        movdqa  xmm3, XMMWORD PTR .LC0[rip]
        xor     eax, eax
        pxor    xmm1, xmm1
        movdqa  xmm4, XMMWORD PTR .LC1[rip]
        shr     edx, 2
.L4:
        movdqa  xmm0, xmm3
        add     eax, 1
        paddd   xmm3, xmm4
        movdqa  xmm2, xmm0
        pmuludq xmm2, xmm0
        psrlq   xmm0, 32
        pmuludq xmm0, xmm0
        pshufd  xmm2, xmm2, 8
        pshufd  xmm0, xmm0, 8
        punpckldq       xmm2, xmm0
        paddd   xmm1, xmm2
        cmp     eax, edx
        jne     .L4
        movdqa  xmm0, xmm1
        mov     eax, edi
        psrldq  xmm0, 8
        and     eax, -4
        paddd   xmm1, xmm0
        add     eax, 1
        movdqa  xmm0, xmm1
        psrldq  xmm0, 4
        paddd   xmm1, xmm0
        movd    edx, xmm1
        test    dil, 3
        je      .L1
.L7:
        mov     ecx, eax
        imul    ecx, eax
        add     eax, 1
        add     edx, ecx
        cmp     edi, eax
        jge     .L7
.L1:
        mov     eax, edx
        ret
.L8:
        xor     edx, edx
        mov     eax, edx
        ret
.L9:
        mov     eax, 1
        xor     edx, edx
        jmp     .L7
.LC0:
        .long   1
        .long   2
        .long   3
        .long   4
.LC1:
        .long   4
        .long   4
        .long   4
        .long   4

这是gcc 12.2-O3,GCC也没有使用平方和公式,我也不知道为什么要检查数字是否大于17,但是由于某种原因,gcc确实比clang和rustc做了很多运算。
这是clang 15.0.0加上-O3

test    edi, edi
    jle     .LBB0_1
    lea     eax, [rdi - 1]
    lea     ecx, [rdi - 2]
    imul    rcx, rax
    lea     eax, [rdi - 3]
    imul    rax, rcx
    shr     rax
    imul    eax, eax, 1431655766
    shr     rcx
    lea     ecx, [rcx + 4*rcx]
    add     ecx, eax
    lea     eax, [rcx + 4*rdi]
    add     eax, -3
    ret
.LBB0_1:
        xor     eax, eax
        ret

我真的不明白clang在做什么样的优化,但是rustc、clang和gcc不喜欢n(n+1)(2n+1)/6
然后我计算了它们的性能。Rust的表现明显好于gcc和clang。这是100次执行的平均结果。使用第11代英特尔酷睿i7 - 11800h@2.30 GHz

Rust: 0.2 microseconds
Clang: 3 microseconds
gcc: 5 microseconds

有人能解释一下性能差异吗?

    • 编辑**C++:
int sum_of_squares(int n){
    int sum = 0;
    for(int i = 1; i <= n; i++){
        sum += i*i;
    }
    return sum;
}
    • EDIT 2**对于所有想知道我的基准测试代码的人:
use std::time::Instant;
pub fn sum_of_squares(n: i32) -> i32 {
    let mut sum = 0;
    for i in 1..n+1 {
        sum += i*i;
    }
    sum
}

fn main() {
    let start = Instant::now();
    let result = sum_of_squares(1000);
    let elapsed = start.elapsed();

    println!("Result: {}", result);
    println!("Elapsed time: {:?}", elapsed);
}

在C++中:

#include <chrono>
#include <iostream>

int sum_of_squares(int n){
    int sum = 0;
    for(int i = 1; i <= n; i++){
        sum += i*i;
    }
    return sum;
}

int main() {
    auto start = std::chrono::high_resolution_clock::now();
    int result = sum_of_squares(1000);
    auto end = std::chrono::high_resolution_clock::now();

    std::cout << "Result: " << result << std::endl;
    std::cout << "Elapsed time: "
              << std::chrono::duration_cast<std::chrono::microseconds>(end - start).count()
              << " microseconds" << std::endl;

    return 0;
}
tyg4sfes

tyg4sfes1#

我原以为它会使用平方和的公式,但它没有,它使用了一个神奇的数字1431655766,我一点也不懂。
LLVM确实将该循环转换为公式,但它不同于简单的平方和公式。
This文章比我更好地解释了公式和生成的代码。

a14dhokn

a14dhokn2#

在一台128位乘积的机器上,除以常数c通常是通过乘以2^64 / c来实现的,这就是你奇怪的常数的来源。
现在公式n(n+1)(2n+1)/ 6在n较大时会溢出,而和不会溢出,所以这个公式只能非常非常小心地使用。

vcudknz3

vcudknz33#

Clang在C中使用-O3做了同样的优化,但GCC还没有。参见on GodBolt.AFAIK,默认的Rust编译器内部使用LLVM,就像Clang一样。这就是为什么它们会产生类似的代码。GCC使用一个简单的循环,使用SIMD指令向量化,而Clang使用一个公式,就像你在问题中给出的公式一样。
C
代码中的优化汇编代码如下所示:

sum_of_squares(int):                    # @sum_of_squares(int)
        test    edi, edi
        jle     .LBB0_1
        lea     eax, [rdi - 1]
        lea     ecx, [rdi - 2]
        imul    rcx, rax
        lea     eax, [rdi - 3]
        imul    rax, rcx
        shr     rax
        imul    eax, eax, 1431655766
        shr     rcx
        lea     ecx, [rcx + 4*rcx]
        add     ecx, eax
        lea     eax, [rcx + 4*rdi]
        add     eax, -3
        ret
.LBB0_1:
        xor     eax, eax
        ret

这个优化主要来自于IndVarSimplify optimization pass,我们可以看到一些变量是32位编码的,而另一些是33位编码的(在主流平台上需要64位寄存器),代码基本上是这样的:

if(edi == 0)
    return 0;
eax = rdi - 1;
ecx = rdi - 2;
rcx *= rax;
eax = rdi - 3;
rax *= rcx;
rax >>= 1;
eax *= 1431655766;
rcx >>= 1;
ecx = rcx + 4*rcx;
ecx += eax;
eax = rcx + 4*rdi;
eax -= 3;
return eax;

这可以进一步简化为以下等效的C++代码:

if(n == 0)
    return 0;
int64_t m = n;
int64_t tmp = ((m - 3) * (m - 1) * (m - 2)) / 2;
tmp = int32_t(int32_t(tmp) * int32_t(1431655766));
return 5 * ((m - 1) * (m - 2) / 2) + tmp + (4*m - 3);

注意,为了清楚起见,忽略了一些强制转换和溢出。
神奇的数字1431655766来自于对与3相除相关的溢出的一种校正。实际上,1431655766 / 2**32 ~= 0.33333333348855376.Clang利用32位溢出来生成公式n(n+1)(2n+1)/6的快速实现。

相关问题