如何优化HoldKarp算法的java实现以缩短运行时间?

如何优化HoldKarp算法的java实现以缩短运行时间?,java,algorithm,optimization,Java,Algorithm,Optimization,我使用Java实现的HOLD Karp algo来解决25个城市的TSP问题。 该计划通过了4个城市 当它在25个城市运行时,它不会停几个小时。我使用jVisualVM来查看热点是什么,经过一些优化后,现在它显示出来了 98%的时间用于实时计算,而不是Map.contains或Map.get 所以我想听听你的建议,下面是代码: private void solve() throws Exception { long beginTime = System.currentTi

我使用Java实现的HOLD Karp algo来解决25个城市的TSP问题。 该计划通过了4个城市

当它在25个城市运行时,它不会停几个小时。我使用jVisualVM来查看热点是什么,经过一些优化后,现在它显示出来了 98%的时间用于实时计算,而不是Map.contains或Map.get

所以我想听听你的建议,下面是代码:

    private void solve() throws Exception {
        long beginTime = System.currentTimeMillis();
        int counter = 0;

        List<BitSetEndPointID> previousCosts;
        List<BitSetEndPointID> currentCosts;
        //maximum number of elements is c(n,[n/2])
        //To calculate m-set's costs just need to keep (m-1)set's costs
        List<BitSetEndPointID> lastKeys = new ArrayList<BitSetEndPointID>();
        int m;
        if (totalNodes < 10) {
            //for test data, generate them on the fly
            SetUtil3.generateMSet(totalNodes);
        }
        //m=1
        BitSet beginSet = new BitSet();
        beginSet.set(0);
        previousCosts = new ArrayList<BitSetEndPointID>(1);
        BitSetEndPointID beginner = new BitSetEndPointID(beginSet, 0);
        beginner.setCost(0f);
        previousCosts.add(beginner);

        //for m=2 to totalNodes
        for (m = 2; m <= totalNodes; m++) {// sum(m=2..n 's C(n,m)*(m-1)(m-1)) ==> O(n^2 * 2^n)
            //pick m elements from total nodes, the element id is the index of nodeCoordinates
            // the first node is always present

            BitSet[] msets;
            if (totalNodes < 10) {
                msets = SetUtil3.msets[m - 1];
            } else {
                //for real data set, will read from serialized file
                msets = SetUtil3.getMsets(totalNodes, m-1);
            }
            currentCosts = new ArrayList<BitSetEndPointID>(msets.length);
            //System.out.println(m + " sets' size: " + msets.size());
            for (BitSet mset : msets) { //C(n,m) mset
                int[] candidates = allSetBits(mset, m);
                //mset is a BitSet which makes sure begin point 0 comes first
                //so end point candidate begins with 1. candidate[0] is always begin point 0
                for (int i = 1; i < candidates.length; i++) { // m-1 bits are set
                    //set the new last point as j, j must not be the same as begin point 0
                    int j = candidates[i];
                    //middleNodes = mset -{j}
                    BitSet middleNodes = (BitSet) mset.clone();
                    middleNodes.clear(j);
                    //loop through all possible points which are second to the last
                    //and get min(A[S-{j},k] + k->j), k!=j
                    float min = Float.MAX_VALUE;
                    int k;
                    for (int ki = 0; ki < candidates.length; ki++) {// m-1 calculation
                        k = candidates[ki];
                        if (k == j) continue;
                        float middleCost = 0;
                        BitSetEndPointID key = new BitSetEndPointID(middleNodes, k);
                        int index = previousCosts.indexOf(key);
                        if (index != -1) {
                            //System.out.println("get value from  map in m " + m + " y key " + middleNodes);
                            middleCost = previousCosts.get(index).getCost();
                        } else if (k == 0 && !middleNodes.equals(beginSet)) {
                            continue;
                        } else {
                            System.out.println("middleCost not found!");
                            continue;
//                            System.exit(-1);
                        }


                        float lastCost = distances[k][j];
                        float cost = middleCost + lastCost;
                        if (cost < min) {
                            min = cost;
                        }

                        counter++;
                        if (counter % 500000 == 0) {
                            try {
                                Thread.currentThread().sleep(100);
                            } catch (InterruptedException iex) {
                                System.out.println("Who dares interrupt my precious sleep?!");
                            }
                        }
                    }
                    //set the costs for chosen mset and last point j
                    BitSetEndPointID key = new BitSetEndPointID(mset, j);
                    key.setCost(min);
                    currentCosts.add(key);

//                    System.out.println("===========================================>mset " + mset + " and end at " +
//                            j + " 's min cost: " + min);
//                    if (m == totalNodes) {
//                        lastKeys.add(key);
//                    }
                }
            }
            previousCosts = currentCosts;
            System.out.println("...");
        }

        calcLastStop(lastKeys, previousCosts);
        System.out.println(" cost " + (System.currentTimeMillis() - beginTime) / 60000 + " minutes.");
    }


    private void calcLastStop(List<BitSetEndPointID> lastKeys, List<BitSetEndPointID>  costs) {
        //last step, calculate the min(A[S={1..n},k] +k->1)
        float finalMinimum = Float.MAX_VALUE;
        for (BitSetEndPointID key : costs) {
            float middleCost = key.getCost();
            Integer endPoint = key.lastPointID;
            float lastCost = distances[endPoint][0];
            float cost = middleCost + lastCost;
            if (cost < finalMinimum) {
                finalMinimum = cost;
            }
        }
        System.out.println("final result: " + finalMinimum);
    }
private void solve()引发异常{
long beginTime=System.currentTimeMillis();
int计数器=0;
列出以前的费用;
列出当前成本;
//元素的最大数量为c(n,[n/2])
//要计算m-set的成本,只需要保留(m-1)set的成本
List lastKeys=new ArrayList();
int m;
if(总节点数<10){
//对于测试数据,动态生成它们
SetUtil3.generateMSet(totalNodes);
}
//m=1
BitSet beginSet=新位集();
beginSet.set(0);
以前的成本=新阵列列表(1);
BitSetEndPointID初学者=新的BitSetEndPointID(beginSet,0);
初学者。设置成本(0f);
以前的成本。添加(初学者);
//对于m=2到totalNodes
对于(m=2;m mset“+mset+”和结束于”+
//j+的最小成本:“+min”;
//if(m==totalNodes){
//添加(键);
//                    }
}
}
以前的成本=当前成本;
System.out.println(“…”);
}
calcLastStop(最后一个键、以前的成本);
System.out.println(“cost”+(System.currentTimeMillis()-beginTime)/60000+“分钟”);
}
私有void calcLastStop(列出LastKey,列出成本){
//最后一步,计算最小值(A[S={1..n},k]+k->1)
float finalMinimum=float.MAX_值;
for(BitSetEndPointID密钥:成本){
float middleCost=key.getCost();
整数端点=key.lastPointID;
浮动成本=距离[终点][0];
浮动成本=中间成本+上次成本;
如果(成本<最终最小值){
最终最小值=成本;
}
}
System.out.println(“最终结果:+finalMinimum”);
}

您可以通过使用原语数组(可能需要比对象列表更好的内存布局)和直接操作位掩码(没有位集或其他对象)来加快代码的速度。以下是一些代码(它生成随机图,但您可以轻松更改它,以便它读取图形):


也许我没有说清楚,我的意思是在优化之前,我花了很多时间在Map方法上。现在我使用Object[]替换Map。我投票决定将此问题作为主题外的问题结束,因为它属于主题外的问题。@NicoSchertler,我同意您的意见,并尝试进行类似的投票,但它只给了我5个选项,说明此问题应转移到何处。这可能是一个愚蠢的问题,但当您标记为主题外的问题时,如何指定您想要的StackExchange组?@Choirbean你必须写一篇自定义评论。@NicoSchertler没有。OP不是要求评论,他们只是想让它更快。优化完全围绕着这个主题。它运行得非常完美!你能解释一下这些问题吗?1.(面具和面具)的含义是什么(1@user1532146当且仅当第i个节点已在路径中时,才在掩码中设置第i个位。前两个表达式检查该位是否已设置。最后一个表达式设置对应于下一个节点的位。
import java.io.*;
import java.util.*;

class Main {

    final static float INF = 1e10f;

    public static void main(String[] args) {
        final int n = 25;
        float[][] dist = new float[n][n];
        Random random = new Random();
        for (int i = 0; i < n; i++)
            for (int j = i + 1; j < n; j++)
                dist[i][j] = dist[j][i] = random.nextFloat();
        float[][] dp = new float[n][1 << n];
        for (int i = 0; i < dp.length; i++)
            Arrays.fill(dp[i], INF);
        dp[0][1] = 0.0f;
        for (int mask = 1; mask < (1 << n); mask++) {
            for (int lastNode = 0; lastNode < n; lastNode++) {
                if ((mask & (1 << lastNode)) == 0)
                    continue; 
                for (int nextNode = 0; nextNode < n; nextNode++) {
                    if ((mask & (1 << nextNode)) != 0)
                        continue;
                    dp[nextNode][mask | (1 << nextNode)] = Math.min(
                            dp[nextNode][mask | (1 << nextNode)],
                            dp[lastNode][mask] + dist[lastNode][nextNode]);
                }
            }   
        }
        double res = INF;
        for (int lastNode = 0; lastNode < n; lastNode++)
            res = Math.min(res, dist[lastNode][0] + dp[lastNode][(1 << n) - 1]);
        System.out.println(res);
    }
}
time java Main
...
real    2m5.546s
user    2m2.264s
sys     0m1.572s