C++ AVX2根据条件将连续元素扩展为稀疏向量?(如AVX512 VPD)

C++ AVX2根据条件将连续元素扩展为稀疏向量?(如AVX512 VPD),c++,intrinsics,avx2,C++,Intrinsics,Avx2,有人知道如何矢量化以下代码吗 uint32_t r[8]; uint16_t* ptr; for (int j = 0; j < 8; ++j) if (r[j] < C) r[j] = *(ptr++); uint32\u t r[8]; uint16_t*ptr; 对于(int j=0;j

有人知道如何矢量化以下代码吗

uint32_t r[8];
uint16_t* ptr;
for (int j = 0; j < 8; ++j)
    if (r[j] < C)
        r[j] = *(ptr++);
uint32\u t r[8];
uint16_t*ptr;
对于(int j=0;j<8;++j)
if(r[j]

这基本上是一个蒙面收集行动。自动矢量器无法处理此问题。如果ptr是uint32_t*,则应可通过直接实现。但即便如此,如何生成正确的索引向量呢?仅仅使用压缩加载并洗牌结果(需要类似的索引向量)不是更快吗?

更新的答案:主要代码段已被重写为函数和解决方案 已添加适用于AMD处理器的

正如Peter Cordes在评论中提到的,AVX-512指令
vpexpandd
在这里很有用。 下面的函数
\u mm256\u mask\u expand\u epi32\u AVX2\u BMI()
\u mm256\u mask\u expand\u epi32\u AVX2()
至少要模仿这个指令。AVX2_BMI变型适用于英特尔Haswell处理器和更新的处理器。
\u mm256\u mask\u expand\u epi32\u AVX2()
功能适用于速度较慢或速度较慢的AMD处理器 缺少
pdep
指令,如Ryzen处理器。 在这个函数中,有几个高吞吐量的指令, 例如移位和简单的算术运算,用来代替
pdep
指令。 AMD处理器的另一种可能性是 一次只处理4个元素,并使用一个很小的(16个元素)查找表来检索 shuf_面具

在这两个函数下面,将展示如何使用它们对标量代码进行矢量化

答案使用了与Peter Cordes的答案相似的想法, 其中讨论了基于面具的左包装。在这个答案中,BMI2 指令
pext
用于计算置换向量。 这里我们使用
pdep
指令来计算置换向量。 函数
\u mm256\u mask\u expand\u epi32\u AVX2()
查找置换向量 以不同的方式通过计算
请注意,英特尔Haswell处理器或更新版本支持
pdep
指令。AMD Zen/Ryzen CPU也支持它,但该体系结构的吞吐量和延迟数非常差:都是18个周期。有趣的是,在比较之前,您需要一个
vpmovzx
,这样您可以从
vpmovmskb
中获得每个元素4位,而不是2位。我觉得应该有办法只使用一个
pdep
。但我想,如果把4位索引解包成8位索引,在其他地方可能需要额外的指令。我看不到用整数乘法直接替换
pdep
的方法。如果专门针对AMD进行优化,您可能会完全避免pdep,或者使用
vpmovsxwq
(使用16位元素进行比较后)在
vpmovzkb
之前解包比较到qword@PeterCordes是,但请注意,
vpmovzx
的结果不会用作比较的输入。因此,不幸的是,掩码是8x32位而不是8x16位。事实上,可以通过与我的文章中所述相同的方式去除一个
pdep
。Christoph对该评论的评论以及随后的讨论也与此相关。啊,对了,没有有意义的变量名很容易混淆。我忘了哪个东西更窄。我想在扩展到256b向量之前,我们可以在16位元素上使用更窄的洗牌。但是为
vpshufb
构建控制向量比较困难,在AVX512之前,没有16位粒度的字洗牌和变量控制。(如果AVX512F没有
vpexpandd
,您可以在这里使用
vpermw
,将索引展开成256b向量,同时执行
vpmovzxwd
工作)。无论如何,对于AMD来说,避免使用
vpermd-ymm
是值得考虑的,因为
vpermd
是多个UOP。@PeterCordes对于AMD来说,最好一次只处理4个元素,并使用一个小的(16个元素)查找表来检索
shuf\u掩码
。这避免了昂贵的
pdep
vpermd-ymm
指令。如果您有AVX512F,您希望使用合并掩码(而不是
{z}
零掩码)根据比较掩码扩展连续源数据@wim的答案是使用
pdep
,因为AVX2没有展开/压缩向量指令。如果您可以将
r[]
的元素归零,而不是让它们保持不变,则效率会更高。重新命名以更好地匹配实际问题(和答案)。
/*     gcc -O3 -m64 -Wall -mavx2 -march=broadwell mask_expand_avx.c     */
#include <immintrin.h>
#include <stdio.h>
#include <stdint.h>

__m256i _mm256_mask_expand_epi32_AVX2_BMI(__m256i src, __m256i mask, __m256i insert_vals, int* nonz){ 
    /* Scatter the insert_vals to the positions indicated by mask.                                                                    */               
    /* Blend the src with these scattered insert_vals.                                                                                */
    /* Return also the number of nonzeros in mask (which is inexpensive here                                                          */
    /* because _mm256_movemask_epi8(mask) has to be computed anyway.)                                                                          */
    /* This code is suitable for Intel Haswell and newer processors.                                                                  */
    /* This code is less suitble for AMD Ryzen processors, due to the                                                                 */
    /* slow pdep instruction on those processors, see _mm256_mask_expand_epi32_AVX2                                                   */
    uint32_t all_indx         = 0x76543210;
    uint32_t mask_int32       = _mm256_movemask_epi8(mask);                           /* Packed mask of 8 nibbles                     */
    uint32_t wanted_indx      = _pdep_u32(all_indx, mask_int32);                      /* Select the right nibbles from all_indx       */
    uint64_t expand_indx      = _pdep_u64(wanted_indx, 0x0F0F0F0F0F0F0F0F);           /* Expand the nibbles to bytes                  */
    __m128i  shuf_mask_8bit   = _mm_cvtsi64_si128(expand_indx);                       /* Move to AVX-128 register                     */
    __m256i  shuf_mask        = _mm256_cvtepu8_epi32(shuf_mask_8bit);                 /* Expand bytes to 32-bit integers              */
    __m256i  insert_vals_exp  = _mm256_permutevar8x32_epi32(insert_vals, shuf_mask);  /* Expand insert_vals to the right positions    */
    __m256i  dst              = _mm256_blendv_epi8(src, insert_vals_exp, mask);       /* src is replaced by insert_vals_exp at the postions indicated by mask */
             *nonz            = _mm_popcnt_u32(mask_int32) >> 2;
             return dst;
}


__m256i _mm256_mask_expand_epi32_AVX2(__m256i src, __m256i mask, __m256i insert_vals, int* nonz){ 
    /* Scatter the insert_vals to the positions indicated by mask.                                                                    */               
    /* Blend the src with these scattered insert_vals.                                                                                */
    /* Return also the number of nonzeros in mask.                                                                                    */
    /* This code is an alternative for the _mm256_mask_expand_epi32_AVX2_BMI function.                                                */
    /* In contrast to that code, this code doesn't use the BMI instruction pdep.                                                      */
    /* Therefore, this code is suitable for AMD processors.                                                                            */
    __m128i  mask_lo          = _mm256_castsi256_si128(mask);                      
    __m128i  mask_hi          = _mm256_extracti128_si256(mask, 1);                  
    __m128i  mask_hi_lo       = _mm_packs_epi32(mask_lo, mask_hi);                    /* Compressed 128-bits (8 x 16-bits) mask       */
             *nonz            = _mm_popcnt_u32(_mm_movemask_epi8(mask_hi_lo)) >> 1;
    __m128i  prefix_sum       = mask_hi_lo;
    __m128i  prefix_sum_shft  = _mm_slli_si128(prefix_sum, 2);                        /* The permutation vector is based on the       */
             prefix_sum       = _mm_add_epi16(prefix_sum, prefix_sum_shft);           /* Prefix sum of the mask.                      */
             prefix_sum_shft  = _mm_slli_si128(prefix_sum, 4);
             prefix_sum       = _mm_add_epi16(prefix_sum, prefix_sum_shft);
             prefix_sum_shft  = _mm_slli_si128(prefix_sum, 8);
             prefix_sum       = _mm_add_epi16(prefix_sum, prefix_sum_shft);
    __m128i  shuf_mask_16bit  = _mm_sub_epi16(_mm_set1_epi16(-1), prefix_sum);
    __m256i  shuf_mask        = _mm256_cvtepu16_epi32(shuf_mask_16bit);               /* Expand 16-bit integers to 32-bit integers    */
    __m256i  insert_vals_exp  = _mm256_permutevar8x32_epi32(insert_vals, shuf_mask);  /* Expand insert_vals to the right positions    */
    __m256i  dst              = _mm256_blendv_epi8(src, insert_vals_exp, mask);       /* src is replaced by insert_vals_exp at the postions indicated by mask */
             return dst;
}


/* Unsigned integer compare _mm256_cmplt_epu32 doesn't exist                                                    */
/* The next two lines are based on Paul R's answer https://stackoverflow.com/a/32945715/2439725                 */
#define _mm256_cmpge_epu32(a, b) _mm256_cmpeq_epi32(_mm256_max_epu32(a, b), a)
#define _mm256_cmplt_epu32(a, b) _mm256_xor_si256(_mm256_cmpge_epu32(a, b), _mm256_set1_epi32(-1))

int print_input(uint32_t* r, uint32_t C, uint16_t* ptr);
int print_output(uint32_t* r, uint16_t* ptr);

int main(){
    int       nonz;
    uint32_t  r[8]        = {6, 3, 1001, 2, 1002, 7, 5, 1003};
    uint32_t  r_new[8];
    uint32_t  C           = 9;
    uint16_t* ptr         = malloc(8*2);  /* allocate 16 bytes for 8 uint16_t's */
              ptr[0] = 11; ptr[1] = 12; ptr[2] = 13;ptr[3] = 14; ptr[4] = 15; ptr[5] = 16; ptr[6] = 17; ptr[7] = 18;
    uint16_t* ptr_new;

              printf("Test values:\n");
              print_input(r,C,ptr);

    __m256i   src         = _mm256_loadu_si256((__m256i *)r);
    __m128i   ins         = _mm_loadu_si128((__m128i *)ptr);
    __m256i   insert_vals = _mm256_cvtepu16_epi32(ins);
    __m256i   mask_C      = _mm256_cmplt_epu32(src,_mm256_set1_epi32(C));   


              printf("Output _mm256_mask_expand_epi32_AVX2_BMI:\n");
    __m256i   output      = _mm256_mask_expand_epi32_AVX2_BMI(src, mask_C, insert_vals, &nonz);
                            _mm256_storeu_si256((__m256i *)r_new,output);
              ptr_new     = ptr + nonz;
              print_output(r_new,ptr_new);              


              printf("Output _mm256_mask_expand_epi32_AVX2:\n");
              output      = _mm256_mask_expand_epi32_AVX2(src, mask_C, insert_vals, &nonz);
                            _mm256_storeu_si256((__m256i *)r_new,output);
              ptr_new     = ptr + nonz;
              print_output(r_new,ptr_new);              


              printf("Output scalar loop:\n");
              for (int j = 0; j < 8; ++j)
                  if (r[j] < C)
                      r[j] = *(ptr++);
              print_output(r,ptr);              

              return 0;
}

int print_input(uint32_t* r, uint32_t C, uint16_t* ptr){
    printf("r[0]..r[7]        =     %4u  %4u  %4u  %4u  %4u  %4u  %4u  %4u  \n",r[0],r[1],r[2],r[3],r[4],r[5],r[6],r[7]);
    printf("Threshold value C =     %4u  %4u  %4u  %4u  %4u  %4u  %4u  %4u  \n",C,C,C,C,C,C,C,C);
    printf("ptr[0]..ptr[7]    =     %4hu  %4hu  %4hu  %4hu  %4hu  %4hu  %4hu  %4hu  \n\n",ptr[0],ptr[1],ptr[2],ptr[3],ptr[4],ptr[5],ptr[6],ptr[7]);
    return 0;
}

int print_output(uint32_t* r, uint16_t* ptr){
    printf("r[0]..r[7]        =     %4u  %4u  %4u  %4u  %4u  %4u  %4u  %4u  \n",r[0],r[1],r[2],r[3],r[4],r[5],r[6],r[7]);
    printf("ptr               = %p \n\n",ptr);
    return 0;
}
$ ./a.out
Test values:
r[0]..r[7]        =        6     3  1001     2  1002     7     5  1003  
Threshold value C =        9     9     9     9     9     9     9     9  
ptr[0]..ptr[7]    =       11    12    13    14    15    16    17    18  

Output _mm256_mask_expand_epi32_AVX2_BMI:
r[0]..r[7]        =       11    12  1001    13  1002    14    15  1003  
ptr               = 0x92c01a 

Output _mm256_mask_expand_epi32_AVX2:
r[0]..r[7]        =       11    12  1001    13  1002    14    15  1003  
ptr               = 0x92c01a 

Output scalar loop:
r[0]..r[7]        =       11    12  1001    13  1002    14    15  1003  
ptr               = 0x92c01a