带dask的混淆矩阵

带dask的混淆矩阵,dask,Dask,我试图使用Dask计算混淆矩阵元素。 从算法的角度来看,我的实现似乎还可以。 但是,当我在两个大小都为100万的阵列上运行它时,它会花费很长时间 有人对如何优化这段代码有什么建议吗 def confusion_matrix_dask(truth,predictions,labels_list=[]): TP=0 FP=0 FN=0 TN=0 if not labels_list: TP=(truth[predictions==1]==1).

我试图使用Dask计算混淆矩阵元素。 从算法的角度来看,我的实现似乎还可以。 但是,当我在两个大小都为100万的阵列上运行它时,它会花费很长时间

有人对如何优化这段代码有什么建议吗

def confusion_matrix_dask(truth,predictions,labels_list=[]):
    TP=0
    FP=0
    FN=0
    TN=0
    if not labels_list:
        TP=(truth[predictions==1]==1).sum()
        FP=(truth[predictions!=1]==1).sum()
        TN=(truth[predictions!=1]!=1).sum()
        FN=(truth[predictions==1]!=1).sum()
    for label in labels_list:
        TP=(truth[predictions==label]==label).sum()+TP
        FP=(truth[predictions!=label]==label).sum()+FP
        TN=(truth[predictions!=label]!=label).sum()+TN
        FN=(truth[predictions==label]!=label).sum()+FN


    return np.array([[TN.compute(), FP.compute()] , [TN.compute() ,FN.compute()]])

您应该注意的一个快速改进:

import dask
TP, FP, TN, FN = dask.compute(TP, FP, TN, FN)
而不是对每个对象调用
.compute()
。这将共享公共数据和任务,从而减少要完成的总工作量