assembly 对两个向量(uint64_t类型)中的每个int8_t元素进行饱和相加

svgewumm  于 2023-04-30  发布在  其他
关注(0)|答案(4)|浏览(134)

我最近遇到了一个问题:
向量中有8个元素,每个元素由int8_t表示。
在x86_64中实现一个将两个向量(uint64_t类型)相加的算法。
元素相加时应采用饱和算术。
例如:
80 + 60 = 127
(−40)+(−100)= −128
最大的挑战是所施加的限制:

*除ret外,不得使用任何条件指令;不能使用跳转、cmove、set等。
*解决方案长度不能超过48条指令(存在短于37条指令的解决方案)

我想不出有什么解决方案能满足这些限制。谁能给予我一些提示吗?C中的例子是受欢迎的。

我只能使用“标准”、传输、算术、逻辑指令和标准寄存器:
*mov cbw/cwde/cdqe cwd/cdq/cqo movzx movsx
*add sub imul mul idiv div inc dec neg
*and or xor not sar sarx shr shrx shl shlx ror rol
*lea ret

o2rvlv0m

o2rvlv0m1#

这里是一个版本(已测试,不需要imul),用clang-16编译时需要22条指令

/*
 * Adds eight signed 8-bit lanes packed in x and y with signed saturation,
 * treating every byte independently (SWAR, no branches).
 *
 * Each lane of the result is clamped to [-128, 127].
 */
uint64_t add(uint64_t x, uint64_t y) {
    const uint64_t msb = 0x8080808080808080ULL;
    const uint64_t low7 = ~msb;

    /* Lane MSB is set where x and y carry the same sign. */
    const uint64_t same_sign = ~(x ^ y) & msb;

    /* Add the low 7 bits only; the carry into bit 7 stays inside its lane. */
    uint64_t sum = (x & low7) + (y & low7);

    /* Overflow: signs agree but the carry into bit 7 differs from that sign. */
    const uint64_t ovf = (sum ^ y) & same_sign;
    const uint64_t ovf_lsb = ovf >> 7;

    /* Widen each overflow flag (0x80) into a full-lane 0xff mask. */
    const uint64_t ovf_mask = (ovf << 1) - ovf_lsb;

    /* Bit 7 now holds the complement of the true sum's sign bit. */
    sum ^= same_sign;

    /* Normal lanes get their sign restored; overflowed lanes become 0x80. */
    const uint64_t base = (sum & ~ovf_mask) ^ msb;

    /* Positive overflow lanes need 0x7f, i.e. 0x80 minus one. */
    const uint64_t dec = ovf_lsb & ~(sum >> 7);

    return base - dec;
}

汇编:

# Compiler output for the C function add(x, y) shown above it (x in rdi, y in rsi, result in rax).
# Comments use the variable names of that C function.
mov     rdx, rsi
xor     rdx, rdi
not     rdx                             # ~(x ^ y), same value as x ^ ~y
movabs  r8, -9187201950435737472        # signmask = 0x8080808080808080
and     rdx, r8                         # eq = (x ^ ~y) & signmask
movabs  rcx, 9187201950435737471        # ~signmask = 0x7f7f7f7f7f7f7f7f
and     rdi, rcx                        # xv = x & ~signmask
and     rcx, rsi                        # yv = y & ~signmask
add     rcx, rdi                        # xv += yv (kept in rcx)
xor     rsi, rcx                        # xv ^ y
and     rsi, rdx                        # satbits = (xv ^ y) & eq
lea     rax, [rsi + rsi]                # satbits << 1
shr     rsi, 7                          # satadd = satbits >> 7
xor     rcx, rdx                        # xv ^= eq
not     rax
add     rax, rsi                        # ~(satbits << 1) + satadd == ~satmask
and     rax, rcx                        # xv & ~satmask
xor     rax, r8                         # t0 = (xv & ~satmask) ^ signmask
shr     rcx, 7                          # xv >> 7
not     rcx                             # ~(xv >> 7)
and     rcx, rsi                        # t1 = satadd & ~(xv >> 7)
sub     rax, rcx                        # result = t0 - t1
6za6bjd0

6za6bjd02#

我用C++写的是这样的:

#include <cstdint>

// Byte-wise signed saturating addition of the eight int8_t lanes packed in
// a and b. Each lane is added independently and clamped to [-128, 127].
uint64_t add(uint64_t a, uint64_t b) {
    const uint64_t sign_bits = 0x8080808080808080ULL;
    const uint64_t sign_a = a & sign_bits;
    const uint64_t sign_b = b & sign_bits;

    // Add with the sign bits stripped so no carry crosses a lane boundary,
    // then fold the sign bits back in to obtain the wrapped 8-bit sums.
    uint64_t result = ((a ^ sign_a) + (b ^ sign_b)) ^ sign_a ^ sign_b;
    const uint64_t sign_sum = result & sign_bits;

    // Both operands non-negative but the wrapped sum is negative.
    const uint64_t over = sign_sum & ~(sign_a | sign_b);
    // Both operands negative but the wrapped sum is non-negative.
    const uint64_t under = (sign_a & sign_b) & ~sign_sum;

    // Force overflowed lanes to 0x7f: set the low seven bits, clear the sign.
    result = (result | ((over >> 7) * 127)) & ~over;
    // Force underflowed lanes to 0x80: clear the low seven bits, set the sign.
    result = (result & ~((under >> 7) * 127)) | under;
    return result;
}

然后我转到https://godbolt.org/,看看各种编译器生成了什么。clang-16给出了33条指令:

# Compiler output for the C++ add(a, b) shown above it (a in rdi, b in rsi, result in rax).
# Comments use the variable names of that C++ function.
add(unsigned long, unsigned long):
        movabs  rdx, -9187201950435737472   # 0x8080808080808080 (per-byte sign bits)
        mov     rax, rdi
        and     rax, rdx                    # asigns
        mov     rcx, rsi
        and     rcx, rdx                    # bsigns
        movabs  r8, 9187201950435737471     # 0x7f7f7f7f7f7f7f7f (per-byte low 7 bits)
        mov     r9, rdi
        and     r9, r8                      # a with sign bits cleared (a ^ asigns)
        and     r8, rsi                     # b with sign bits cleared (b ^ bsigns)
        add     r8, r9                      # 7-bit lane sums
        xor     rax, rcx                    # asigns ^ bsigns
        xor     rax, r8                     # sum with the sign bits folded back in
        or      rsi, rdi
        not     rsi                         # ~(a | b)
        and     rdx, rsi                    # sign bits where both a and b are non-negative
        and     rdx, r8                     # sat (high): r8 bit 7 equals sum bit 7 in those lanes
        mov     rsi, rdx
        shr     rsi, 7                      # sat >> 7
        mov     r8, rdx
        sub     r8, rsi                     # sat - (sat >> 7) == (sat >> 7) * 127
        or      r8, rax                     # sum |= (sat >> 7) * 127
        xor     r8, rdx                     # clears the sat bits (they are set in sum here)
        not     rax                         # ~sum (sum before the high-saturation patch)
        and     rcx, rdi                    # asigns & bsigns
        and     rcx, rax                    # sat (low) = asigns & bsigns & ~sumsigns
        mov     rdx, rcx
        shr     rdx, 7                      # sat >> 7
        mov     rax, rcx
        sub     rax, rdx                    # (sat >> 7) * 127
        not     rax
        and     rax, r8                     # sum &= ~((sat >> 7) * 127)
        or      rax, rcx                    # sum |= sat
        ret

您可以尝试其他各种选项。

ktecyv1j

ktecyv1j3#

使用paddsb指令对字节向量进行带符号饱和加法。实现可以像这样(假设amd64 sysv abi):

# Packed byte add with signed saturation using MMX: operands arrive in %rdi and %rsi, result is returned in %rax.
movq    %rdi, %mm0  # move the first operand to an MMX register
    movq    %rsi, %mm1  # move the second operand to an MMX register
    paddsb  %mm1, %mm0  # packed add bytes with signed saturation
    movq    %mm0, %rax  # move the result back to a scalar register
    emms                # end MMX mode
    ret                 # return to caller

在没有MMX的情况下,可以使用以下方法。其思想是使用SWAR技术对所有字节并行执行以下算法:

/*
 * Adds two signed bytes with saturation: a result that would leave the
 * int8_t range is replaced by 0x7f (operands non-negative) or 0x80
 * (operands negative).
 */
int8_t addsb(int8_t a, int8_t b) {
    const int8_t wrapped = a + b;

    /* Overflow is only possible when both operands share a sign ... */
    const int same_sign = ((a ^ b) & 0x80) == 0;
    /* ... and it happened if the wrapped sum's sign differs from theirs. */
    const int sign_flipped = ((a ^ wrapped) & 0x80) != 0;

    if (same_sign && sign_flipped) {
        if (a & 0x80) {
            return 0x80;
        }
        return 0x7f;
    }

    return wrapped;
}

以下代码未经测试,但应该可以工作:

# paddsb: byte-wise signed saturating add of the eight int8_t lanes in
# %rdi (a) and %rsi (b); the result is returned in %rax.
# Uses only %rax, %rcx, %rdx, %rdi, %rsi and %r8, so no callee-saved register
# is touched. The lane sums are formed from the low 7 bits of each byte so
# that no carry can cross a byte boundary.
paddsb: mov     $0x7f7f7f7f7f7f7f7f, %rdx       # low 7 bits of every byte
        mov     %rdi, %rcx
        xor     %rsi, %rcx                      # a ^ b
        mov     %rdi, %rax
        and     %rdx, %rax                      # a with sign bits cleared
        mov     %rsi, %r8
        and     %rdx, %r8                       # b with sign bits cleared
        add     %r8, %rax                       # 7-bit sums; carries stay inside their byte
        not     %rdx                            # 0x80 in every byte
        and     %rcx, %rdx                      # sign bits of a ^ b
        xor     %rdx, %rax                      # q = byte-wise wrapped a + b
        not     %rcx                            # ~(a ^ b): bit 7 set where signs agree
        xor     %rax, %rdi                      # a ^ q
        and     %rcx, %rdi                      # bit 7 set where overflow
        shr     $7, %rdi                        # bit 0 set where overflow
        mov     $0x0101010101010101, %rdx       # LSB bit masks
        and     %rdx, %rdi                      # 0x01 where overflow, 0x00 where not
        imul    $0xff, %rdi, %rdi               # 0xff where overflow, 0x00 where not
        shr     $7, %rsi
        and     %rdx, %rsi                      # 0x01 where b negative, 0x00 where not
        mov     $0x7f7f7f7f7f7f7f7f, %rdx
        add     %rsi, %rdx                      # 0x80 where b negative, 0x7f where not
        and     %rdi, %rdx                      # masked to only where overflown
        not     %rdi                            # 0x00 where overflow, 0xff where not
        and     %rdi, %rax                      # q masked to only where not overflown
        or      %rdx, %rax                      # signed saturated sum of a and b
        ret

注意,需要一些额外的处理来避免进位从一个字节传播到下一个字节。

tvz2xvvm

tvz2xvvm4#

下面的代码使用了一种普通的带符号饱和的字节加法方法,但在指令计数和执行时间方面与Falk Hüffner's excellent algorithm非常有竞争力。
为了避免跨越字节通道边界,模拟SIMD算法的经典方法是对低7位和最高有效位分别执行计算,然后合并部分结果。在这种情况下,这也有助于检测有符号整数溢出,其一个定义是:进入最高有效位的进位与该位产生的进位输出不同。
有符号整数加法溢出只能在加数的符号相同时发生。如果发生溢出,字节大小的特殊结果(下面代码中的spc)是0x7f或0x80,因此可以从任一加数的符号计算。
溢出标志扩展为全0或全1的全字节掩码,用于选择常规加法结果(下面代码中的res)或传统复用习惯中的特殊溢出结果。
问题列出了允许的BMI 2指令集扩展(2013年引入)的各种指令,所以我假设使用BMI 1扩展的andn指令也是允许的,尽管问题中没有明确列出。
我在Windows 10机器上开发了我的实现epaddsb,因此代码使用了Windows对x86-64的调用约定。对于Linux使用的System V ABI,更改这一点很简单:只需交换几个寄存器名。为了与Falk Hüffner的算法进行比较,我使用最新的Intel oneAPI编译器编译了他的C代码,并在hpaddsb中捕获了生成的代码。
epaddsb需要21条指令(不含ret),而hpaddsb需要20条指令(不含ret)。在我的基于Skylake CPU的PC上,两种变体的性能在±2%的测量噪声水平内是相同的。

PUBLIC  epaddsb

_TEXT   SEGMENT
  
        ALIGN 16

;; epaddsb(a,b): emulated byte-wise 64-bit addition with signed saturation
;;
;; Windows x86-64 calling convention:
;; function arguments: rcx, rdx, {r8, r9}
;; function return value: rax
;; scratch registers: rax, rcx, rdx, r8, r9, {r10, r11}
;;
;; Outline: add the low 7 bits of each byte separately from the sign bits,
;; derive a per-byte overflow flag, widen it to a full-byte mask, and use the
;; mask to select between the regular sum (res) and the saturated value (spc).
;; As used below, "andn dst, s1, s2" computes dst = ~s1 & s2.

epaddsb PROC
        mov  rax, 7f7f7f7f7f7f7f7fh ; NMSB_MASK = ~MSB_MASK
        mov  r8, rcx                ; copy of a
        mov  r9, rdx                ; copy of b
        and  rcx, rax               ; a & NMSB_MASK
        and  rdx, rax               ; b & NMSB_MASK
        xor  r9, r8                 ; sum = a ^ b
        add  rdx, rcx               ; res = (a & NMSB_MASK) + (b & NMSB_MASK)
        andn rcx, rax, r8           ; a & ~NMSB_MASK (sign bits of a)
        xor  r8, rdx                ; res ^ a
        shr  rcx, 7                 ; (a & ~NMSB_MASK) >> 7
        andn r8, r9, r8             ; ofl = (res ^ a) & ~sum
        add  rcx, rax               ; spc = ((a & ~NMSB_MASK) >> 7) + NMSB_MASK
        andn r9, rax, r9            ; sum & ~NMSB_MASK
        xor  rdx, r9                ; res = res ^ (sum & ~NMSB_MASK)
        andn r8, rax, r8            ; ofl & ~NMSB_MASK (keep only the per-byte flag bit)
        lea  r9, [r8 + r8]          ; ofl << 1
        shr  r8, 7                  ; ofl >> 7
        sub  r9, r8                 ; mask = (ofl << 1) - (ofl >> 7)
        andn rax, r9, rdx           ; res & ~mask
        and  rcx, r9                ; spc & mask
        or   rax, rcx               ; res = (spc & mask) | (res & ~mask)
        ret
epaddsb ENDP

        ALIGN 16

;; Falk Hüffner's algorithm from https://stackoverflow.com/a/76090715/780717
;; Compiled by Intel(R) oneAPI DPC++/C++ compiler version 2023.0.0
;;
;; Arguments arrive in rcx (x) and rdx (y); the result is returned in rax.
;; Comments use the variable names of the C source of that algorithm.
;; As used below, "andn dst, s1, s2" computes dst = ~s1 & s2.

hpaddsb PROC
        mov  rax, rdx              ;
        xor  rax, rcx              ; x ^ y
        mov  r8, 8080808080808080h ; signmask
        andn r9, rax, r8           ; eq = ~(x ^ y) & signmask
        mov  r10, 7f7f7f7f7f7f7f7fh; ~signmask
        and  rcx, r10              ; xv = x & ~signmask
        and  r10, rdx              ; yv = y & ~signmask
        add  r10, rcx              ; xv += yv (kept in r10)
        xor  rdx, r10              ; xv ^ y
        and  rdx, r9               ; satbits = (xv ^ y) & eq
        lea  rax, [rdx + rdx]      ; satbits << 1
        shr  rdx, 7                ; satadd = satbits >> 7
        xor  r10, r9               ; xv ^= eq
        not  rax                   ;
        add  rax, rdx              ; ~(satbits << 1) + satadd == ~satmask
        and  rax, r10              ; xv & ~satmask
        xor  rax, r8               ; t0 = (xv & ~satmask) ^ signmask
        shr  r10, 7                ; xv >> 7
        andn rcx, r10, rdx         ; t1 = satadd & ~(xv >> 7)
        sub  rax, rcx              ; result = t0 - t1
    ret                     
hpaddsb ENDP

        ALIGN 16

_TEXT   ENDS

        END

我在下面展示我的测试脚手架。我构建如下

ml64 /c paddsb.obj paddsb.asm
icx /W4 /Ox /QxHOST paddsb_stackoverflow.c paddsb.obj

使用Microsoft Macro Assembler 14.27.29112.0和英特尔oneAPI DPC++/C++编译器2023.0.0。

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* number of random operand pairs main() checks before reporting success */
#define NBR_TEST_CASES     (1000000000)
/* nonzero: main() tests hpaddsb; zero: main() tests epaddsb */
#define TEST_HUEFFNER_ALGO (0)

/* emulated byte-wise 64-bit addition with signed saturation; in assembly */
extern uint64_t epaddsb (uint64_t a, uint64_t b); /* algorithm: N. Juffa */
extern uint64_t hpaddsb (uint64_t a, uint64_t b); /* algorithm: F. Hüffner */

/*
 * Reference implementation of byte-wise addition with signed saturation.
 *
 * Treats a and b as eight packed int8_t lanes, adds corresponding lanes in
 * int precision, clamps each sum to [-128, 127] and repacks the lanes.
 */
uint64_t paddsb_ref (uint64_t a, uint64_t b)
{
    uint64_t packed = 0;

    for (int lane = 0; lane < 8; lane++) {
        const int shift = 8 * lane;
        const int8_t ai = (int8_t)(uint8_t)(a >> shift);
        const int8_t bi = (int8_t)(uint8_t)(b >> shift);
        int sum = ai + bi; /* promoted to int, cannot overflow */

        if (sum > 127) {
            sum = 127;
        } else if (sum < -128) {
            sum = -128;
        }
        packed |= (uint64_t)(uint8_t)(int8_t)sum << shift;
    }
    return packed;
}

/*  https://groups.google.com/forum/#!original/comp.lang.c/qFv18ql_WlU/IK8KGZZFJx4J */
/* State of the KISS64 pseudo-random generator used to produce test operands.
   All arithmetic is on uint64_t and therefore wraps modulo 2^64. */
static uint64_t kiss64_x = 1234567890987654321ULL;
static uint64_t kiss64_c = 123456123456123456ULL;
static uint64_t kiss64_y = 362436362436362436ULL;
static uint64_t kiss64_z = 1066149217761810ULL;
static uint64_t kiss64_t;
/* Updates kiss64_x / kiss64_c (with kiss64_t as scratch) and yields kiss64_x;
   presumably a multiply-with-carry step, going by the name. */
#define MWC64  (kiss64_t = (kiss64_x << 58) + kiss64_c, \
                kiss64_c = (kiss64_x >> 6), kiss64_x += kiss64_t, \
                kiss64_c += (kiss64_x < kiss64_t), kiss64_x)
/* Three xor-shift steps (<<13, >>17, <<43) on kiss64_y; yields the new kiss64_y. */
#define XSH64  (kiss64_y ^= (kiss64_y << 13), kiss64_y ^= (kiss64_y >> 17), \
                kiss64_y ^= (kiss64_y << 43))
/* Linear congruential step on kiss64_z; yields the new kiss64_z. */
#define CNG64  (kiss64_z = 6906969069ULL * kiss64_z + 1234567ULL)
/* Next random 64-bit value: sum of the three component generators.
   Each expansion advances all three state variables. */
#define KISS64 (MWC64 + XSH64 + CNG64)

/*
 * Randomized test driver: feeds NBR_TEST_CASES pseudo-random operand pairs to
 * the selected assembly routine (hpaddsb when TEST_HUEFFNER_ALGO is nonzero,
 * epaddsb otherwise) and compares every result with paddsb_ref.
 *
 * Returns EXIT_FAILURE after printing the first mismatch, EXIT_SUCCESS when
 * all cases agree.
 */
int main (void)
{
    uint64_t res, ref, a, b, count = 0;

    printf ("Testing %s's algo\n", TEST_HUEFFNER_ALGO ? "Hueffner" : "Juffa");
    do {
        a = KISS64;
        b = KISS64;
        ref = paddsb_ref (a, b);
#if TEST_HUEFFNER_ALGO
        res = hpaddsb (a, b);
#else // TEST_HUEFFNER_ALGO
        res = epaddsb (a, b);
#endif // TEST_HUEFFNER_ALGO
        if (res != ref) {
            /* PRIx64 matches uint64_t on every platform; "%llx" does not
               where uint64_t is unsigned long. */
            printf ("error @ a=%016" PRIx64 " b=%016" PRIx64
                    ":  res=%016" PRIx64 "  ref=%016" PRIx64 "\n",
                    a, b, res, ref);
            return EXIT_FAILURE;
        }
        count++;
    } while (count < NBR_TEST_CASES);
    printf ("test passed\n");
    return EXIT_SUCCESS;
}

相关问题