X86 用AVX实现指数函数的最快速度

X86 用AVX实现指数函数的最快速度,x86,simd,avx,exponential,avx2,X86,Simd,Avx,Exponential,Avx2,我正在寻找一种高效(快速)的指数函数近似值,该函数在AVX元素(单精度浮点)上运行。即不带SVML的-\uuuuuM256\uMM256\uExp\uPS(\uuuuM256 x) 相对精度应为~1e-6或~20尾数位(2^20中的1部分) 如果它是用英特尔内部函数以C风格编写的,我会很高兴的。 代码应该是可移植的(Windows、macOS、Linux、MSVC、ICC、GCC等) 这与类似,但该问题要求非常快且精度较低(目前的答案给出的精度约为1e-3) 另外,这个问题是寻找AVX/AVX

我正在寻找一种高效(快速)的指数函数近似值,该函数在AVX元素(单精度浮点)上运行。即不带SVML的-
\uuuuuM256\uMM256\uExp\uPS(\uuuuM256 x)

相对精度应为~1e-6或~20尾数位(2^20中的1部分)

如果它是用英特尔内部函数以C风格编写的,我会很高兴的。
代码应该是可移植的(Windows、macOS、Linux、MSVC、ICC、GCC等)


这与类似,但该问题要求非常快且精度较低(目前的答案给出的精度约为1e-3)

另外,这个问题是寻找AVX/AVX2(和FMA)。但请注意,这两个问题的答案很容易在SSE4
\uuuuM128
或AVX2
\uuuuuuM256
之间移植,因此未来读者应根据所需的精度/性能权衡进行选择。

您可以:

为此,您只需要AVX的加法和乘法运算。如果硬编码,然后乘以而不是除以,那么像1/2、1/6、1/24等系数会更快

根据精度要求,获取序列中的任意多个成员。请注意,您将得到相对误差:对于较小的
z
而言,其绝对值可能是
1e-6
,但是对于较大的
z
而言,其绝对值将大于
1e-6
,仍然
abs(E-E1)/abs(E)-1
小于
1e-6
(其中,
E
是精确指数,
E1
是近似值)


更新:正如@Peter Cordes在一篇评论中提到的,可以通过分离整数和小数部分的幂运算来提高精度,通过操纵二进制
float
表示的指数字段来处理整数部分(基于2^x,而不是e^x)。那么您的泰勒级数只需在小范围内最小化误差。

中的
exp
函数使用范围缩减和类似切比雪夫近似的多项式,与AVX指令并行计算8
exp
-s。使用正确的编译器设置确保
addps
mulp在可能的情况下,s
与FMA指令融合

将原始的
exp
代码从转换为可移植的(跨不同编译器)C/AVX2 intrinsics代码非常简单。原始代码使用gcc风格的对齐属性和巧妙的宏。修改后的代码使用标准的
\u mm256\u set1\u ps()
则位于小测试代码和表格下方。修改后的代码需要AVX2

以下代码用于简单测试:

intmain(){
int i;
浮动十五[8];
浮动yv[8];
__m256 x=_mm256_setr_ps(1.0f、2.0f、3.0f、4.0f、5.0f、6.0f、7.0f、8.0f);
__m256 y=exp256μps(x);
_mm256(xv,x);
_mm256_存储_ps(yv,y);

对于(i=0;i因为
exp()的快速计算)
需要对IEEE-754浮点操作数的指数字段进行操作,
AVX
实际上不适合此计算,因为它缺少整数运算。因此,我将重点介绍
AVX2
。对融合乘法加法的支持在技术上是与
AVX2
不同的一个功能,因此我提供了两种代码模式hs,使用和不使用FMA,由宏
usefma
控制

下面的代码将
exp()
计算到接近所需的精度10-6。FMA的使用在这里没有提供任何显著的改进,但在支持它的平台上,它应该提供性能优势

先前用于较低精度SSE实现的算法不能完全扩展到相当精确的实现,因为它包含一些数值特性较差的计算,但这在该上下文中并不重要。而不是使用[0,1]中的
f
或[-½,½]中的
f
计算ex=2i*2f,在较窄的区间[-½log2,½log2]内用
f
计算ex=2i*ef是有利的,其中
log
表示自然对数

为此,我们首先计算
i=rint(x*log2(e))
,然后计算
f=x-log(2)*i
。重要的是,后一种计算需要使用比本机精度更高的精度来传递精确的简化参数以传递给核心近似值。为此,我们使用Cody-Waite方案,首次发表于W.J.Cody&W.Waite,《基本功能软件手册》,Prentice Hall 1980。常数对数(2)分为较大量级的“高”部分和较小量级的“低”部分,其中包含“高”部分和数学常数之间的差异

选择尾数中有足够尾数零位的高部分,这样
i
与“高”部分的乘积就可以精确地表示为本机精度。在这里,我选择了一个有八个尾数零位的“高”部分,因为
i
肯定适合八位

本质上,我们计算f=x-i*log(2)high-i*log(2)low。这个简化的参数被传递到核心近似值,它是一个多项式,结果按2i缩放,如前一个答案所示

#包括
#定义使用\u FMA 0
/*计算[-87.33654f,88.72283]中x的exp(x)
最大相对误差:3.1575e-6(使用_-FMA=0);3.1533e-6(使用_-FMA=1)
*/
__m256更快更精确的运算速度avx2(m256 x)
{
__m256 t,f,p,r;
__m256i,j;
常数m256 l2e=_mm256_set1_ps(1.44269541f);/*log2(e)*/
常数m256 l2h=_mm256_set1_ps(-6.93145752e-1f);/*-log(2)_hi*/
常数m256 l2l=_mm256_set1_ps(-1.42860677e-6f);/*-log(2)_lo*/
/*[-log(2)/2,log(2)/2]中exp()的核心近似系数*/
常数m256 c0=\uMM256\uSET1\uPS(0.041944388f);
常数m256 c1=\uMM256\uSET1\uPS(0.168006673f);
常数
exp(z) = 1 + z + pow(z,2)/2 + pow(z,3)/6 + pow(z,4)/24 + ...
i = 0, x = 1.000000e+00, y = 2.718282e+00 
i = 1, x = 2.000000e+00, y = 7.389056e+00 
i = 2, x = 3.000000e+00, y = 2.008554e+01 
i = 3, x = 4.000000e+00, y = 5.459815e+01 
i = 4, x = 5.000000e+00, y = 1.484132e+02 
i = 5, x = 6.000000e+00, y = 4.034288e+02 
i = 6, x = 7.000000e+00, y = 1.096633e+03 
i = 7, x = 8.000000e+00, y = 2.980958e+03 
i      x                     y = exp256_ps(x)      double precision exp        relative error

i = 0  x = 1.000000000e+00   y = 2.718281746e+00   exp_dbl = 2.718281828e+00   rel_err =-3.036785947e-08
i = 1  x =-2.000000000e+00   y = 1.353352815e-01   exp_dbl = 1.353352832e-01   rel_err =-1.289636419e-08
i = 2  x = 3.000000000e+00   y = 2.008553696e+01   exp_dbl = 2.008553692e+01   rel_err = 1.672817689e-09
i = 3  x =-4.000000000e+00   y = 1.831563935e-02   exp_dbl = 1.831563889e-02   rel_err = 2.501162103e-08
i = 4  x = 5.000000000e+00   y = 1.484131622e+02   exp_dbl = 1.484131591e+02   rel_err = 2.108215155e-08
i = 5  x =-6.000000000e+00   y = 2.478752285e-03   exp_dbl = 2.478752177e-03   rel_err = 4.380257261e-08
i = 6  x = 7.000000000e+00   y = 1.096633179e+03   exp_dbl = 1.096633158e+03   rel_err = 1.849522682e-08
i = 7  x =-8.000000000e+00   y = 3.354626242e-04   exp_dbl = 3.354626279e-04   rel_err =-1.101575118e-08
float fast_exp(float x)
{
    const float c1 = 0.007972914726F;
    const float c2 = 0.1385283768F;
    const float c3 = 2.885390043F;
    const float c4 = 1.442695022F;      
    x *= c4; //convert to 2^(x)
    int intPart = (int)x;
    x -= intPart;
    float xx = x * x;
    float a = x + c1 * xx * x;
    float b = c3 + c2 * xx;
    float res = (b + a) / (b - a);
    reinterpret_cast<int &>(res) += intPart << 23; // res *= 2^(intPart)
    return res;
}
__m256 _mm256_exp_ps(__m256 _x)
{
    __m256 c1 = _mm256_set1_ps(0.007972914726F);
    __m256 c2 = _mm256_set1_ps(0.1385283768F);
    __m256 c3 = _mm256_set1_ps(2.885390043F);
    __m256 c4 = _mm256_set1_ps(1.442695022F);
    __m256 x = _mm256_mul_ps(_x, c4); //convert to 2^(x)
    __m256 intPartf = _mm256_round_ps(x, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
    x = _mm256_sub_ps(x, intPartf);
    __m256 xx = _mm256_mul_ps(x, x);
    __m256 a = _mm256_add_ps(x, _mm256_mul_ps(c1, _mm256_mul_ps(xx, x))); //can be improved with FMA
    __m256 b = _mm256_add_ps(c3, _mm256_mul_ps(c2, xx));
    __m256 res = _mm256_div_ps(_mm256_add_ps(b, a), _mm256_sub_ps(b, a));
    __m256i intPart = _mm256_cvtps_epi32(intPartf); //res = 2^intPart. Can be improved with AVX2!
    __m128i ii0 = _mm_slli_epi32(_mm256_castsi256_si128(intPart), 23);
    __m128i ii1 = _mm_slli_epi32(_mm256_extractf128_si256(intPart, 1), 23);     
    __m128i res_0 = _mm_add_epi32(ii0, _mm256_castsi256_si128(_mm256_castps_si256(res)));
    __m128i res_1 = _mm_add_epi32(ii1, _mm256_extractf128_si256(_mm256_castps_si256(res), 1));
    return _mm256_insertf128_ps(_mm256_castsi256_ps(_mm256_castsi128_si256(res_0)), _mm_castsi128_ps(res_1), 1);
}