在Lua Torch中,两个零矩阵的乘积具有nan项
我在Lua/torch中遇到了torch.mm函数的一个奇怪行为。下面是一个简单的程序来演示这个问题在Lua Torch中,两个零矩阵的乘积具有nan项,lua,torch,Lua,Torch,我在Lua/torch中遇到了torch.mm函数的一个奇怪行为。下面是一个简单的程序来演示这个问题 iteration = 0; a = torch.Tensor(2, 2); b = torch.Tensor(2, 2); prod = torch.Tensor(2,2); a:zero(); b:zero(); repeat prod = torch.mm(a,b); ent = prod[{2,1}]; iteration = iteration + 1; un
iteration = 0;
a = torch.Tensor(2, 2);
b = torch.Tensor(2, 2);
prod = torch.Tensor(2,2);
a:zero();
b:zero();
repeat
prod = torch.mm(a,b);
ent = prod[{2,1}];
iteration = iteration + 1;
until ent ~= ent
print ("error at iteration " .. iteration);
print (prod);
该程序由一个循环组成,其中程序将两个零2x2矩阵相乘,并测试乘积矩阵的entry是否等于nan。似乎程序应该永远运行,因为乘积应该始终等于0,因此ent应该为0。但是,程序会打印:
error at iteration 548
0.000000 0.000000
nan nan
[torch.DoubleTensor of size 2x2]
为什么会这样
更新:
可以找到为
torch.mm
自动生成Lua包装的代码部分
当您在循环中写入prod=torch.mm(a,b)
时,它在幕后对应于以下C代码(由于以下原因由此包装器生成):
因此:
- 将创建一个新的结果张量,并使用适当的尺寸调整其大小
- 但是这个新的张量没有初始化,即这里没有
或显式填充,因此它指向垃圾内存,可能包含NaN-scalloc
- 这个张量被推到堆栈上,以便在Lua端可用作返回值
prod
张量(即在循环中,prod
对初始值进行阴影处理)
另一方面,调用torch.mm(prod,a,b)
会使用初始的prod
张量来存储结果(在这种情况下,在幕后不需要创建专用的张量)。因为在您的代码片段中,您没有用给定的值初始化/填充它,所以它也可能包含垃圾
在这两种情况下,核心运算都是agemm
乘法,比如C=beta*C+alpha*a*B,beta=0,alpha=1。结果看起来是这样的:
real *a_ = a;
for(i = 0; i < m; i++)
{
real *b_ = b;
for(j = 0; j < n; j++)
{
real sum = 0;
for(l = 0; l < k; l++)
sum += a_[l*lda]*b_[l];
b_ += ldb;
/*
* WARNING: beta*c[j*ldc+i] could give NaN even if beta=0
* if the other operand c[j*ldc+i] is NaN!
*/
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
}
a_++;
}
real*a_uu=a;
对于(i=0;i
评论是我的
因此:
torch.mm(a,b)
:在每次迭代中,在不初始化的情况下创建一个新的结果张量(它可能包含NaN-s)因此每次迭代都有返回NaN-s的风险(参见上述警告)torch.mm(prod,a,b)
:由于未初始化prod
张量,因此存在相同的风险。但是:这种风险只存在于重复/直到循环的第一次迭代中,因为就在prod
之后,用0-s填充,并在后续迭代中重复使用THDoubleTensor_fill(arg1,0);
)
在案例2中:您应该首先初始化prod
,并使用torch.mm(prod,a,b)
构造来避免任何NaN问题
--
编辑:这个问题现在已经解决了(请参见此)。如果您能更多地了解您的环境,如硬件、操作系统、SHA和解释器(LuaJIT 2.0?2.1?或Lua 5.1?5.2?)。以下是我的环境信息:Linux CentOS、LuaJIT 2.0.3、Torch 7。Torch是在没有LAPACK和BLAS库的情况下编译的。在我用OpenBLAS重新编译之后,问题就消失了。但我仍然对它的原因感兴趣。谢谢你的回答和修复!
real *a_ = a;
for(i = 0; i < m; i++)
{
real *b_ = b;
for(j = 0; j < n; j++)
{
real sum = 0;
for(l = 0; l < k; l++)
sum += a_[l*lda]*b_[l];
b_ += ldb;
/*
* WARNING: beta*c[j*ldc+i] could give NaN even if beta=0
* if the other operand c[j*ldc+i] is NaN!
*/
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
}
a_++;
}