用bsxfun加速Matlab嵌套for循环
我有一个图用bsxfun加速Matlab嵌套for循环,matlab,optimization,vectorization,bsxfun,Matlab,Optimization,Vectorization,Bsxfun,我有一个图nxngraphW描述为它的邻接矩阵和每个节点的组标签(整数)的n向量 我需要为每对组计算组c中的节点和组d中的节点之间的链接(边)数量。为此,我为循环编写了一个嵌套的,但我确信这不是计算我称之为mcd的矩阵的最快方法,即计算组c和d之间的边数的矩阵。 是否可以通过bsxfun使此操作更快 function mcd = interlinks(W,ci) %// W is the adjacency matrix of a simple undirected graph %// ci a
nxn
graphW
描述为它的邻接矩阵和每个节点的组标签(整数)的n
向量
我需要为每对组计算组c
中的节点和组d
中的节点之间的链接(边)数量。为此,我为循环编写了一个嵌套的,但我确信这不是计算我称之为mcd
的矩阵的最快方法,即计算组c
和d
之间的边数的矩阵。
是否可以通过bsxfun
使此操作更快
function mcd = interlinks(W,ci)
%// W is the adjacency matrix of a simple undirected graph
%// ci are the group labels of every node in the graph, can be from 1 to |C|
n = length(W); %// number of nodes in the graph
m = sum(nonzeros(triu(W))); %// number of edges in the graph
ncomms = length(unique(ci)); %// number of groups of ci
mcd = zeros(ncomms); %// this is the matrix that counts the number of edges between group c and group d, twice the number of it if c==d
for c=1:ncomms
nodesc = find(ci==c); %// nodes in group c
for d=1:ncomms
nodesd = find(ci==d); %// nodes in group d
M = W(nodesc,nodesd); %// submatrix of edges between c and d
mcd(c,d) = sum(sum(M)); %// count of edges between c and d
end
end
%// Divide diagonal half because counted twice
mcd(1:ncomms+1:ncomms*ncomms)=mcd(1:ncomms+1:ncomms*ncomms)/2;
例如,在这里的图片中,邻接矩阵是
W=[0 1 1 0 0 0;
1 0 1 1 0 0;
1 1 0 0 1 1;
0 1 0 0 1 0;
0 0 1 1 0 1;
0 0 1 0 1 0];
组标签向量为ci=[1 1 2 3]
,结果矩阵为mcd
:
mcd=[3 2 1;
2 1 1;
1 1 0];
例如,它意味着组1自身有3个链接,组2有2个链接,组3有1个链接
IIUC并假设ci
是一个排序数组,看起来基本上是在进行分块求和,但块大小不规则。因此,您可以使用一种方法,沿着行和列,然后在ci
中的移位位置进行微分,这将基本上为您提供分块求和
实现如下所示-
%// Get cumulative sums row-wise and column-wise
csums = cumsum(cumsum(W,1),2)
%/ Get IDs of shifts and thus get cumsums at those positions
[~,idx] = unique(ci) %// OR find(diff([ci numel(ci)]))
csums_indexed = csums(idx,idx)
%// Get the blockwise summations by differentiations on csums at shifts
col1 = diff(csums_indexed(:,1),[],1)
row1 = diff(csums_indexed(1,:),[],2)
rest2D = diff(diff(csums_indexed,[],2),[],1)
out = [[csums_indexed(1,1) ; col1] [row1 ; rest2D]]
如果您不反对mex函数,可以使用下面的代码
测试代码
测试结果
使用这些设置的测试结果:
base avg fun time = 4.94275
mex avg fun time = 0.0373092
bsx avg fun time = 0.126406
norm(x1 - x2) = 0
norm(x1 - x3) = 0
基本上,对于较小的n_标签
,bsx函数做得很好,但您可以将其设置得足够大,以便mex函数更快
C++代码
将它放入像interlink\u mex.cpp
这样的文件中,并使用mex interlink\u mex.cpp
进行编译。你需要一个C++编译器在你的机器等…< /P>
#include "mex.h"
#include "matrix.h"
#include <math.h>
// Author: Matthew Gunn
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
if(nrhs != 2)
mexErrMsgTxt("Invalid number of inputs. Shoudl be 2 input argument.");
if(nlhs != 1)
mexErrMsgTxt("Invalid number of outputs. Should be 1 output arguments.");
if(!mxIsLogical(prhs[0])) {
mexErrMsgTxt("First argument should be a logical array (i.e. type logical)");
}
if(!mxIsDouble(prhs[1])) {
mexErrMsgTxt("Second argument should be an array of type double");
}
const mxArray *W = prhs[0];
const mxArray *ci = prhs[1];
size_t W_m = mxGetM(W);
size_t W_n = mxGetN(W);
if(W_m != W_n)
mexErrMsgTxt("Rows and columns of W are not equal");
// size_t ci_m = mxGetM(ci);
size_t ci_n = mxGetNumberOfElements(ci);
mxLogical *W_data = mxGetLogicals(W);
// double *W_data = mxGetPr(W);
double *ci_data = mxGetPr(ci);
size_t *ci_data_size_t = (size_t*) mxCalloc(ci_n, sizeof(size_t));
size_t ncomms = 0;
double intpart;
for(size_t i = 0; i < ci_n; i++) {
double x = ci_data[i];
if(x < 1 || x > 65536 || modf(x, &intpart) != 0.0) {
mexErrMsgTxt("Input ci is not all integers from 1 to a maximum value of 65536 (can edit source code to change this)");
}
size_t xx = (size_t) x;
if(xx > ncomms)
ncomms = xx;
ci_data_size_t[i] = xx - 1;
}
mxArray *mcd = mxCreateDoubleMatrix(ncomms, ncomms, mxREAL);
double *mcd_data = mxGetPr(mcd);
for(size_t i = 0; i < W_n; i++) {
size_t ii = ci_data_size_t[i];
for(size_t j = 0; j < W_n; j++) {
size_t jj = ci_data_size_t[j];
mcd_data[ii + jj * ncomms] += (W_data[i + j * W_m] != 0);
}
}
for(size_t i = 0; i < ncomms * ncomms; i+= ncomms + 1) //go along diagonal
mcd_data[i]/=2; //divide by 2
mxFree(ci_data_size_t);
plhs[0] = mcd;
}
#包括“mex.h”
#包括“矩阵h”
#包括
//作者:马修·冈恩
void MEX函数(int nlhs、mxArray*plhs[]、int nrhs、const mxArray*prhs[]){
如果(nrhs!=2)
MEXERMSGSTXT(“输入数无效。应为2个输入参数”);
如果(nlhs!=1)
mexErrMsgTxt(“输出数无效。应为1个输出参数”);
如果(!mxiLogical(prhs[0])){
mexErrMsgTxt(“第一个参数应该是逻辑数组(即逻辑类型)”;
}
如果(!mxIsDouble(prhs[1])){
mexErrMsgTxt(“第二个参数应该是double类型的数组”);
}
常量mxArray*W=prhs[0];
常量mxArray*ci=prhs[1];
尺寸W_m=mxGetM(W);
尺寸W n=mxGetN(W);
如果(W_m!=W_n)
MEXERMSGSTXT(“W的行和列不相等”);
//大小ci=mxGetM(ci);
大小ci=mxGetNumberOfElements(ci);
mxLogical*W_data=mxGetLogicals(W);
//双*W_数据=mxGetPr(W);
double*ci_data=mxGetPr(ci);
size_t*ci_data_size_t=(size_t*)mxCalloc(ci_n,sizeof(size_t));
大小\u t ncomms=0;
双内点;
对于(大小i=0;i65536 | | modf(x,&intpart)!=0.0){
mexErrMsgTxt(“输入ci不是从1到最大值65536的所有整数(可以编辑源代码以更改此值)”;
}
尺寸xx=(尺寸)x;
如果(xx>ncomms)
ncomms=xx;
ci_data_size_t[i]=xx-1;
}
mxArray*mcd=mxCreateDoubleMatrix(NCOMM、NCOMM、mxREAL);
双*mcd_数据=mxGetPr(mcd);
对于(大小i=0;i
这个怎么样
C = bsxfun(@eq, ci,unique(ci)');
mcd = C*W*C'
mcd(logical(eye(size(mcd)))) = mcd(logical(eye(size(mcd))))./2;
我想这就是你想要的。你能进一步解释一下3乘3矩阵与你的图像的关系吗?您的示例中的[1 1 2 3]
是您的W
?如果是,创建该mcd
属性需要什么ci
?n有多大?W稀疏吗?ci是否包含从1到ncomms的整数值?在我的示例中,矩阵W是所示图形的邻接矩阵,ci
是图形顶点的成员索引。n
有多大并不重要,如果W
是稀疏的,我只想避免双重嵌套循环。是的,ci
包含从1到ncomms
的整数,表示i-th
顶点的组刚刚编辑了问题,以显示完整的邻接矩阵,如图所示。这正是我想要的,这个实现非常快,工作非常好,但我不理解它背后的逻辑。太好了!,这很简单,我会尽力解释的。第一个矩阵C是由1和0组成的NCOMMX长度(ci)矩阵。其中每行包含一个1,其中ci等于一个唯一值。现在让我们假设你想要计算dsm(1,2):你可以取第一个,你可以乘以W,从左边第一行,从右边第二行。这基本上是将W中的所有位置相加(乘以1或0,然后相加),其中C(1,:)等于1的行和C(2,:)等于1的列。通过做矩阵乘法,你可以一次完成所有的运算。我强烈建议你尝试一个简单的矩阵乘法例子,了解它是如何工作的!而且非常快。(对于大量标签,我的函数仍然更快:P)顺便说一句,我的方法中的大部分时间是生成矩阵C,特别是唯一函数。矩阵乘法非常快。我很想听到改进这一部分的方法……顺便说一句,我注意到如果length(unique(ci))
不等于max(ci)
,我的代码与您的代码并不完全相同,也就是说,我设置了ncomms=max(ci)。这就是为什么我在测试代码中添加了惟一调用,以使用IC变量作为标签id。
C = bsxfun(@eq, ci,unique(ci)');
mcd = C*W*C'
mcd(logical(eye(size(mcd)))) = mcd(logical(eye(size(mcd))))./2;