用bsxfun加速Matlab嵌套for循环

用bsxfun加速Matlab嵌套for循环,matlab,optimization,vectorization,bsxfun,Matlab,Optimization,Vectorization,Bsxfun,我有一个图nxngraphW描述为它的邻接矩阵和每个节点的组标签(整数)的n向量 我需要为每对组计算组c中的节点和组d中的节点之间的链接(边)数量。为此,我为循环编写了一个嵌套的,但我确信这不是计算我称之为mcd的矩阵的最快方法,即计算组c和d之间的边数的矩阵。 是否可以通过bsxfun使此操作更快 function mcd = interlinks(W,ci) %// W is the adjacency matrix of a simple undirected graph %// ci a

我有一个图
nxn
graph
W
描述为它的邻接矩阵和每个节点的组标签(整数)的
n
向量

我需要为每对组计算组
c
中的节点和组
d
中的节点之间的链接(边)数量。为此,我为循环编写了一个嵌套的
,但我确信这不是计算我称之为
mcd
的矩阵的最快方法,即计算组
c
d
之间的边数的矩阵。 是否可以通过
bsxfun
使此操作更快

function mcd = interlinks(W,ci)
%// W is the adjacency matrix of a simple undirected graph
%// ci are the group labels of every node in the graph, can be from 1 to |C|
n = length(W); %// number of nodes in the graph
m = sum(nonzeros(triu(W))); %// number of edges in the graph
ncomms = length(unique(ci)); %// number of groups of ci

mcd = zeros(ncomms); %// this is the matrix that counts the number of edges between group c and group d, twice the number of it if c==d

for c=1:ncomms
    nodesc = find(ci==c); %// nodes in group c
    for d=1:ncomms
        nodesd = find(ci==d); %// nodes in group d
        M = W(nodesc,nodesd); %// submatrix of edges between c and d
        mcd(c,d) = sum(sum(M)); %// count of edges between c and d
    end
end

%// Divide diagonal half because counted twice
mcd(1:ncomms+1:ncomms*ncomms)=mcd(1:ncomms+1:ncomms*ncomms)/2;
例如,在这里的图片中,邻接矩阵是

W=[0 1 1 0 0 0;
   1 0 1 1 0 0;
   1 1 0 0 1 1;
   0 1 0 0 1 0;
   0 0 1 1 0 1;
   0 0 1 0 1 0];
组标签向量为
ci=[1 1 2 3]
,结果矩阵为
mcd

mcd=[3 2 1; 
     2 1 1;
     1 1 0];
例如,它意味着组1自身有3个链接,组2有2个链接,组3有1个链接


IIUC并假设
ci
是一个排序数组,看起来基本上是在进行分块求和,但块大小不规则。因此,您可以使用一种方法,沿着行和列,然后在
ci
中的移位位置进行微分,这将基本上为您提供分块求和

实现如下所示-

%// Get cumulative sums row-wise and column-wise
csums = cumsum(cumsum(W,1),2)

%/ Get IDs of shifts and thus get cumsums at those positions
[~,idx] = unique(ci) %// OR find(diff([ci numel(ci)]))
csums_indexed = csums(idx,idx)

%// Get the  blockwise summations by differentiations on csums at shifts 
col1 = diff(csums_indexed(:,1),[],1)
row1 = diff(csums_indexed(1,:),[],2)
rest2D = diff(diff(csums_indexed,[],2),[],1)
out = [[csums_indexed(1,1) ; col1] [row1 ; rest2D]]

如果您不反对mex函数,可以使用下面的代码

测试代码 测试结果 使用这些设置的测试结果:

base avg fun time = 4.94275 
mex avg fun time = 0.0373092 
bsx avg fun time = 0.126406 
norm(x1 - x2) = 0
norm(x1 - x3) = 0
基本上,对于较小的
n_标签
,bsx函数做得很好,但您可以将其设置得足够大,以便mex函数更快

C++代码 将它放入像
interlink\u mex.cpp
这样的文件中,并使用
mex interlink\u mex.cpp
进行编译。你需要一个C++编译器在你的机器等…< /P>
#include "mex.h"
#include "matrix.h"
#include <math.h>

//  Author: Matthew Gunn

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
  if(nrhs != 2)
    mexErrMsgTxt("Invalid number of inputs.  Shoudl be 2 input argument.");

  if(nlhs != 1)
    mexErrMsgTxt("Invalid number of outputs.  Should be 1 output arguments.");

  if(!mxIsLogical(prhs[0])) {
    mexErrMsgTxt("First argument should be a logical array (i.e. type logical)");
  }
  if(!mxIsDouble(prhs[1])) {
    mexErrMsgTxt("Second argument should be an array of type double");

  }

  const mxArray *W = prhs[0];
  const mxArray *ci = prhs[1];

  size_t W_m = mxGetM(W);
  size_t W_n = mxGetN(W);

  if(W_m != W_n)
    mexErrMsgTxt("Rows and columns of W are not equal");

  //  size_t ci_m = mxGetM(ci);
  size_t ci_n = mxGetNumberOfElements(ci);


  mxLogical *W_data = mxGetLogicals(W);
  //  double *W_data = mxGetPr(W);
  double *ci_data = mxGetPr(ci);

  size_t *ci_data_size_t = (size_t*) mxCalloc(ci_n, sizeof(size_t));
  size_t ncomms = 0;

  double intpart;
  for(size_t i = 0; i < ci_n; i++) {
    double x = ci_data[i];
    if(x < 1 || x > 65536 || modf(x, &intpart) != 0.0) {
       mexErrMsgTxt("Input ci is not all integers from 1 to a maximum value of 65536 (can edit source code to change this)");

     }
    size_t xx = (size_t) x;
    if(xx > ncomms)
      ncomms = xx;
    ci_data_size_t[i] = xx - 1;
  }

  mxArray *mcd = mxCreateDoubleMatrix(ncomms, ncomms, mxREAL);
  double *mcd_data = mxGetPr(mcd);


  for(size_t i = 0; i < W_n; i++) {
    size_t ii = ci_data_size_t[i];
    for(size_t j = 0; j < W_n; j++) {  
      size_t jj = ci_data_size_t[j];
      mcd_data[ii + jj * ncomms] += (W_data[i + j * W_m] != 0);
    }    
  }
  for(size_t i = 0; i < ncomms * ncomms; i+= ncomms + 1) //go along diagonal
    mcd_data[i]/=2; //divide by 2

  mxFree(ci_data_size_t);
  plhs[0] = mcd;
}
#包括“mex.h”
#包括“矩阵h”
#包括
//作者:马修·冈恩
void MEX函数(int nlhs、mxArray*plhs[]、int nrhs、const mxArray*prhs[]){
如果(nrhs!=2)
MEXERMSGSTXT(“输入数无效。应为2个输入参数”);
如果(nlhs!=1)
mexErrMsgTxt(“输出数无效。应为1个输出参数”);
如果(!mxiLogical(prhs[0])){
mexErrMsgTxt(“第一个参数应该是逻辑数组(即逻辑类型)”;
}
如果(!mxIsDouble(prhs[1])){
mexErrMsgTxt(“第二个参数应该是double类型的数组”);
}
常量mxArray*W=prhs[0];
常量mxArray*ci=prhs[1];
尺寸W_m=mxGetM(W);
尺寸W n=mxGetN(W);
如果(W_m!=W_n)
MEXERMSGSTXT(“W的行和列不相等”);
//大小ci=mxGetM(ci);
大小ci=mxGetNumberOfElements(ci);
mxLogical*W_data=mxGetLogicals(W);
//双*W_数据=mxGetPr(W);
double*ci_data=mxGetPr(ci);
size_t*ci_data_size_t=(size_t*)mxCalloc(ci_n,sizeof(size_t));
大小\u t ncomms=0;
双内点;
对于(大小i=0;i65536 | | modf(x,&intpart)!=0.0){
mexErrMsgTxt(“输入ci不是从1到最大值65536的所有整数(可以编辑源代码以更改此值)”;
}
尺寸xx=(尺寸)x;
如果(xx>ncomms)
ncomms=xx;
ci_data_size_t[i]=xx-1;
}
mxArray*mcd=mxCreateDoubleMatrix(NCOMM、NCOMM、mxREAL);
双*mcd_数据=mxGetPr(mcd);
对于(大小i=0;i
这个怎么样

C = bsxfun(@eq, ci,unique(ci)');
mcd = C*W*C'
mcd(logical(eye(size(mcd)))) = mcd(logical(eye(size(mcd))))./2;

我想这就是你想要的。

你能进一步解释一下3乘3矩阵与你的图像的关系吗?您的示例中的
[1 1 2 3]
是您的
W
?如果是,创建该
mcd
属性需要什么
ci
?n有多大?W稀疏吗?ci是否包含从1到ncomms的整数值?在我的示例中,矩阵W是所示图形的邻接矩阵,
ci
是图形顶点的成员索引。
n
有多大并不重要,如果
W
是稀疏的,我只想避免双重嵌套循环。是的,
ci
包含从1到
ncomms
的整数,表示
i-th
顶点的组刚刚编辑了问题,以显示完整的邻接矩阵,如图所示。这正是我想要的,这个实现非常快,工作非常好,但我不理解它背后的逻辑。太好了!,这很简单,我会尽力解释的。第一个矩阵C是由1和0组成的NCOMMX长度(ci)矩阵。其中每行包含一个1,其中ci等于一个唯一值。现在让我们假设你想要计算dsm(1,2):你可以取第一个,你可以乘以W,从左边第一行,从右边第二行。这基本上是将W中的所有位置相加(乘以1或0,然后相加),其中C(1,:)等于1的行和C(2,:)等于1的列。通过做矩阵乘法,你可以一次完成所有的运算。我强烈建议你尝试一个简单的矩阵乘法例子,了解它是如何工作的!而且非常快。(对于大量标签,我的函数仍然更快:P)顺便说一句,我的方法中的大部分时间是生成矩阵C,特别是唯一函数。矩阵乘法非常快。我很想听到改进这一部分的方法……顺便说一句,我注意到如果
length(unique(ci))
不等于
max(ci)
,我的代码与您的代码并不完全相同,也就是说,我设置了ncomms=max(ci)。这就是为什么我在测试代码中添加了惟一调用,以使用IC变量作为标签id。
C = bsxfun(@eq, ci,unique(ci)');
mcd = C*W*C'
mcd(logical(eye(size(mcd)))) = mcd(logical(eye(size(mcd))))./2;