dmvnorm MVN density-RcppArmadillo实现比包含一点Fortran的R包慢

dmvnorm MVN density-RcppArmadillo实现比包含一点Fortran的R包慢,r,rcpp,R,Rcpp,解决方案现在在中联机 我从RcppArmadillo中的mvtnorm包重新实现了dmvnorm。我有点喜欢犰狳,但我想它在普通的Rcpp中也会起作用。dmvnorm的方法基于马氏距离,所以我有一个函数,然后是多元正态密度函数 让我向您展示我的代码: #include <RcppArmadillo.h> #include <Rcpp.h> // [[Rcpp::depends("RcppArmadillo")]] // [[Rcpp::export]] arma::

解决方案现在在中联机


我从RcppArmadillo中的mvtnorm包重新实现了dmvnorm。我有点喜欢犰狳,但我想它在普通的Rcpp中也会起作用。dmvnorm的方法基于马氏距离,所以我有一个函数,然后是多元正态密度函数

让我向您展示我的代码:

#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]

// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::mat mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }



// [[Rcpp::export]]
arma::vec dmvnorm ( arma::mat x,  arma::mat mean,  arma::mat sigma, bool log){ 

arma::vec distval = mahalanobis_arma(x,  mean, sigma);

    double logdet = sum(arma::log(arma::eig_sym(sigma)));
    double log2pi = 1.8378770664093454835606594728112352797227949472755668;
    arma::vec logretval = -( (x.n_cols * log2pi + logdet + distval)/2  ) ;

       if(log){ 
         return(logretval);

       }else { 
       return(exp(logretval));
         }
}
不-(

[编辑]

问题是: 1) 为什么RcppArmadillo实现比普通的R实现慢? 2) 如何创建优于R实现的Rcpp/RcppArmadillo实现

[编辑2]


我将mahalanobis_arma放入mvtnorm::dmvnorm函数中,它也会变慢。

如果您想更快地实现mahalanobis距离,只需重新编写算法并模仿R使用的算法。这非常简单

我稍微修改了你的函数
mahalanobis_arma
,把
mu
变成了
rowvec

基本上我只是把R代码翻译成RcppArmadillo

mahalanobis
function (x, center, cov, inverted = FALSE, ...) 
{
    x <- if (is.vector(x)) 
        matrix(x, ncol = length(x))
    else as.matrix(x)
    x <- sweep(x, 2, center)
    if (!inverted) 
        cov <- solve(cov, ...)
    setNames(rowSums((x %*% cov) * x), rownames(x))
}
<bytecode: 0x6e5b408>
<environment: namespace:stats>
正如您所看到的,新实现比R实现更快。 我很确定,通过使用cholesky分解来求解协方差矩阵,或者使用其他矩阵分解,我们可以做得更好

最后,我们可以将这个
Mahalanobis
函数插入您的
dmvnorm
并测试它:

require(mvtnorm)
set.seed(1)
sigma <- matrix(c(4, 2, 2, 3), ncol = 2)
x <- rmvnorm(n = 5000000, mean = c(1, 2), sigma = sigma, method = "chol")


all.equal(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          c(dmvnorm(x, t(1:2), .2+diag(2), FALSE)))
## [1] TRUE

benchmark(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          dmvnorm(x, t(1:2), .2+diag(2), FALSE),
          order = "elapsed")

##                                                test replications
## 2          dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
## 1 mvtnorm::dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
##   elapsed relative user.self sys.self user.child sys.child
## 2  35.366    1.000    31.117    4.193          0         0
## 1  60.770    1.718    56.666   13.236          0         0
require(mvtnorm)
种子(1)

sigma我不明白你的问题是什么如果大部分计算将由两个实现之间的同一个线性代数库执行,为什么你希望看到显著的改进?这只是表明你可以用任何语言编写慢代码。:)为什么不简单的调用<代码> MVTNORM::从C++中的DVNOMUNE <代码>?你的标题标题是误导的。RCppArmadillo不慢于R;它比R+Fortran慢。Fortran位恰好比R位更重要。这一点的底线是,用低级语言重新编码通常只有在相关操作尚未通过原始R函数中的编译二进制代码时才有帮助……还有一件事:我可能有不同的方法。因此,我必须事先减去它们,然后用平均值0运行此操作。@Inferator您可以再次将
mu
转换为
mat
,方法是确保小心地将其减去
x
(按行)并且维度足够。对。太好了,谢谢!简单但聪明的解决方案。这对Rcpp图库不是很好吗?我把所有内容都放在一个要点中,并添加了openMP:-只需在R中运行Sys.setenv(“PKG_CXXFLAGS”=“-fopenmp”)Sys.setenv(“PKG_LIBS”=“-fopenmp”),然后使用sourceCpp@Inferrator嗯,我认为这可能是一个很好的补充Rcpp画廊,我认为他们接受的要点,如果你想添加一个条目。
mahalanobis
function (x, center, cov, inverted = FALSE, ...) 
{
    x <- if (is.vector(x)) 
        matrix(x, ncol = length(x))
    else as.matrix(x)
    x <- sweep(x, 2, center)
    if (!inverted) 
        cov <- solve(cov, ...)
    setNames(rowSums((x %*% cov) * x), rownames(x))
}
<bytecode: 0x6e5b408>
<environment: namespace:stats>
#include <RcppArmadillo.h>
#include <Rcpp.h>

// [[Rcpp::depends("RcppArmadillo")]]
// [[Rcpp::export]]
arma::vec Mahalanobis(arma::mat x, arma::rowvec center, arma::mat cov){
    int n = x.n_rows;
    arma::mat x_cen;
    x_cen.copy_size(x);
    for (int i=0; i < n; i++) {
        x_cen.row(i) = x.row(i) - center;
    }
    return sum((x_cen * cov.i()) % x_cen, 1);    
}


// [[Rcpp::export]]
arma::vec mahalanobis_arma( arma::mat x ,  arma::rowvec mu, arma::mat sigma ){

  int n = x.n_rows;
  arma::vec md(n);
    for (int i=0; i<n; i++){
        arma::mat x_i = x.row(i) - mu;
        arma::mat Y = arma::solve( sigma, arma::trans(x_i) );
        md(i) = arma::as_scalar(x_i * Y);
    }
    return md;

    }
require(RcppArmadillo)
sourceCpp("mahalanobis.cpp")

set.seed(1)
x <- matrix(rnorm(10000 * 10), ncol = 10)
Sx <- cov(x)


all.equal(c(Mahalanobis(x, colMeans(x), Sx))
          ,mahalanobis(x, colMeans(x), Sx))
## [1] TRUE

all.equal(mahalanobis_arma(x, colMeans(x), Sx)
          ,Mahalanobis(x, colMeans(x), Sx))
## [1] TRUE


require(rbenchmark)
benchmark(Mahalanobis(x, colMeans(x), Sx),
          mahalanobis(x, colMeans(x), Sx),
          mahalanobis_arma(x, colMeans(x), Sx),
          order = "elapsed")


##                                   test replications elapsed
## 1      Mahalanobis(x, colMeans(x), Sx)          100   0.124
## 2      mahalanobis(x, colMeans(x), Sx)          100   0.741
## 3 mahalanobis_arma(x, colMeans(x), Sx)          100   4.509
##   relative user.self sys.self user.child sys.child
## 1    1.000     0.173    0.077          0         0
## 2    5.976     0.804    0.670          0         0
## 3   36.363     4.386    4.626          0         0
require(mvtnorm)
set.seed(1)
sigma <- matrix(c(4, 2, 2, 3), ncol = 2)
x <- rmvnorm(n = 5000000, mean = c(1, 2), sigma = sigma, method = "chol")


all.equal(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          c(dmvnorm(x, t(1:2), .2+diag(2), FALSE)))
## [1] TRUE

benchmark(mvtnorm::dmvnorm(x, t(1:2), .2 + diag(2), FALSE),
          dmvnorm(x, t(1:2), .2+diag(2), FALSE),
          order = "elapsed")

##                                                test replications
## 2          dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
## 1 mvtnorm::dmvnorm(x, t(1:2), 0.2 + diag(2), FALSE)          100
##   elapsed relative user.self sys.self user.child sys.child
## 2  35.366    1.000    31.117    4.193          0         0
## 1  60.770    1.718    56.666   13.236          0         0