ARM-NEON中水平布尔约简的优化

ARM-NEON中水平布尔约简的优化,arm,simd,neon,Arm,Simd,Neon,我正在试验一个跨平台的SIMD库ala,其中一部分是提供一些“水平”SIMD操作。特别是,库提供的API包括any()->bool和all()->bool函数,其中是T类型的K元素的向量,boolN是N位布尔值,即所有的1或所有的0,作为SSE和NEON返回进行比较操作 例如,假设v是一个(一个128位向量),它可能是某个函数的结果。我想计算all(v)=v[0]&&v[1]&&v[2]&&v[3]和any(v)=v[0]| | v[1]| v[2]| | v[3] 使用SSE很容易,例如,将提

我正在试验一个跨平台的SIMD库ala,其中一部分是提供一些“水平”SIMD操作。特别是,库提供的API包括
any()->bool
all()->bool
函数,其中
T
类型的
K
元素的向量,
boolN
N
位布尔值,即所有的1或所有的0,作为SSE和NEON返回进行比较操作

例如,假设
v
是一个
(一个128位向量),它可能是某个函数的结果。我想计算
all(v)=v[0]&&v[1]&&v[2]&&v[3]
any(v)=v[0]| | v[1]| v[2]| | v[3]

使用SSE很容易,例如,将提取每个元素的高位,因此上述类型的
all
变为(使用C intrinsic):

(对于后者,我们也可以做类似的事情,速度可能更快,但基本上是相同的想法。)

有没有其他技巧可以用来实现这一点


是的,我知道使用SIMD向量单元时,水平操作不是很好。但有时它是有用的,例如,mandlebrot的许多SIMD实现会同时在4个点上运行,当所有点都超出范围时,就会跳出内部循环。。。这需要做一个比较,然后做一个横向和横向比较。

注意:今天第一次看arm时,我可能错了

UPD:删除了ARM-V7,并将在单独的答案中记录我们最终所做的事情

ARM-V8。

对于ARM-V8,请看一下glibc的strlen实现:

ARM-V8引入了跨寄存器的缩减。在这里,他们使用min与0进行比较

        uminv        datab2, datav.16b
        mov          tmp1, datav2.d[0]
        cbnz         tmp1, L(main_loop)
找到最小的字符,与0比较-取下16个字节

ARM-V8中还有一些其他的减少,如
vaddvq_u8

我敢肯定,你可以通过
movemask
等类似工具完成大部分你想做的事情

这里另一件有趣的事情是他们如何发现
first\u true

        /* Set te NULL byte as 0xff and the rest as 0x00, move the data into a
           pair of scalars and then compute the length from the earliest NULL
           byte.  */
        cmeq        datav.16b, datav.16b, #0
        mov        data1, datav.d[0]
        mov        data2, datav.d[1]
        cmp        data1, 0
        csel        data1, data1, data2, ne
        sub        len, src, srcin
        rev        data1, data1
        add        tmp2, len, 8
        clz        tmp1, data1
        csel        len, len, tmp2, ne
        add        len, len, tmp1, lsr 3
看起来有点吓人,但我的理解是:

  • 他们只需执行if/else(如果前半部分没有零,则后半部分有),就可以将其缩小到64位数字
  • 使用count前导零来查找位置(我不太了解这里的所有endianness内容,但它是libc-因此这是正确的一个)

  • 因此,如果您只需要V8,那么有一个解决方案。

    这是我目前在中实现的解决方案

    如果您的后端支持C++20,您可以直接使用该库:它有arm-v7、arm-v8(目前只有少量endian)以及从sse2到avx-512的所有x86的实现。它是开源的,并获得麻省理工学院的许可。目前处于测试阶段。如果您正在试用该库,请随时联系(例如遇到问题)

    对所有事情都要一概而论-我还没有设置手臂基准点

    注意:除了基本的all和any之外,我们还有一个
    movemask
    等价物,用于执行更复杂的操作,如
    first\u true
    。这不是问题的一部分,也不令人惊讶,但可以找到代码

    ARM-V7,8字节寄存器

    现在,arm-v7是32位的体系结构,所以我们尽量使用32位元素

    • 任何
    使用成对32位最大值。如果任何元素为真,则最大值为真

    // cast to dwords
    dwords = vpmax_u32(dwords, dwords);
    return vget_lane_u32(dwords, 0);
    
    • 全部
    成对的最小值而不是最大值。也就是你测试的变化。 如果您有4字节的元素-只需测试为真。如果短或字符-您需要测试为-1

    // cast to dwords
    dwords = vpmin_u32(dwords, dwords);
    std::uint32_t combined = vget_lane_u32(dwords, 0);
    
    // Assuming T is your scalar type
    if constexpr ( sizeof(T) >= 4 ) return combined;
    
    // I decided that !~ is better than -1, compiler will figure it out.
    return !~combined; 
    
    ARM-V7,16字节寄存器

    对于大于字符的任何内容,只需将其转换为64位字符即可。以下是转换列表

    对于chars,我发现最好的方法是重新解释为uint32并进行额外检查。 因此,比较所有的==-1和任何>0。 在两个8字节寄存器中拆分似乎更好

    然后只需对dword寄存器执行所有/任何操作

    ARM-v8,8字节

    ARM-v8有64位的支持,所以你可以得到一个64位的通道,这个通道是可以测试的

    ARM-v8,16字节

    我们使用
    vmaxvq_u32
    ,因为对于
    any
    vminvq_u32
    vminvq_u16
    vminvq_u8
    所有
    all
    都没有64位。 (类似于)

    结论

    缺乏基准肯定会让我担心,有些指令有时会有问题,我对此一无所知。
    不管怎样,至少到目前为止,这是我所掌握的最好的方法。

    移动MSKPS更有趣的SSE指令是
    ptest
    。你可以将它用于
    或者
    或者
    或者
    。我认为Neon有相同的指令
    vtest
    。我还没有实现它,但是我想你可以在这里找到答案。@Zboson:
    >令人遗憾的是,vtst在这里并不是特别有用(因为您已经从比较中得到了一个0/-1值的向量)。Nils的建议来自链接答案(饱和加法+读取Q位)通常情况下效果不好,因为Q位是粘性的,所以需要先用RMW清除它。所以通常的方法是在arm32上使用多个
    vpmax
    /
    vpmin
    ,在arm64上使用一个
    umaxv
    /
    uminv
    。我不知道有多少“mandlebrot的SIMD实现将一次在4个点上运行,当所有点都超出范围时,将退出内部循环…”我自己已经做了一段时间了(实际上是8个像素,AVX用于单浮点).对于x86,我使用了
    ptest
    ,但您似乎已经找到了使用ARM的最佳解决方案:即使用arm7两次最小/最大值,使用arm8一次。@StephenCanon,在这种情况下,您可能可以提供答案。相关:要求使用
    movmskps
    等效值
    // cast to dwords
    dwords = vpmax_u32(dwords, dwords);
    return vget_lane_u32(dwords, 0);
    
    // cast to dwords
    dwords = vpmin_u32(dwords, dwords);
    std::uint32_t combined = vget_lane_u32(dwords, 0);
    
    // Assuming T is your scalar type
    if constexpr ( sizeof(T) >= 4 ) return combined;
    
    // I decided that !~ is better than -1, compiler will figure it out.
    return !~combined;