
Java ND4J runs slower on GPU but faster on CPU

Tags: java, neural-network, gpu, cpu, nd4j

Today I tried to use CUDA in my ND4J and DeepLearning4j project. After that, the neural network (imported from Keras) started working faster, but the code below started working slowly.

I have already tried switching the ND4J backend back to native (CPU) and got fast results again.

The problem parts are highlighted with comments (2 lines) in the full code shown further below.
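
For context, ND4J selects its backend from the classpath: the org.nd4j:nd4j-native-platform artifact provides the native (CPU) backend and the org.nd4j:nd4j-cuda-x.x-platform artifacts provide the CUDA backend. A minimal sketch, not part of the question's code, for confirming at runtime which backend is actually active (the class name is illustrative):

    import org.nd4j.linalg.factory.Nd4j;

    public class BackendCheck {
        public static void main(String[] args) {
            // Prints the active ND4J backend class, e.g. CpuBackend for nd4j-native
            // or JCublasBackend for nd4j-cuda. The backend is chosen by which
            // dependency is on the classpath, not by application code.
            System.out.println(Nd4j.getBackend().getClass().getName());
        }
    }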

I would like one of the following:

  • Get faster results while still using the GPU

  • Disable the GPU for this part of the code and keep it only for the neural network


  • OK, I will rewrite the cosine-distance part of the code in my own implementation

    Using a GPU is not only about going faster. The way you set up the problem and the architecture are completely different, and each needs its own approach. Most likely, your code is not set up to run in parallel across 500+ cores the way a GPU expects.
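
To act on that advice without dropping ND4J, the loop marked as the slow part below can hand the GPU one large operation instead of thousands of tiny ones: calling Transforms.cosineDistance once per 64-element vector pays kernel-launch and host-to-device transfer overhead on every call, whereas stacking all candidate vectors into a single matrix computes every cosine similarity in one matrix product. The following is a rough sketch under that assumption; the class and method names are illustrative and not from the question.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class BatchedCosine {
        // Cosine similarity between one query vector (shape [1, d]) and n candidate
        // vectors stacked into a matrix (shape [n, d]), computed as a single
        // matrix-vector product plus norms instead of n separate cosineDistance calls.
        static INDArray cosineSimilarities(INDArray query, INDArray candidates) {
            INDArray dots = candidates.mmul(query.transpose());                  // [n, 1] dot products
            INDArray norms = candidates.norm2(1).reshape(candidates.rows(), 1);  // [n, 1] row norms
            double queryNorm = query.norm2Number().doubleValue();
            return dots.div(norms.mul(queryNorm));                               // cos_i = dot_i / (||c_i|| * ||q||)
        }

        public static void main(String[] args) {
            INDArray query = Nd4j.rand(1, 64);
            INDArray candidates = Nd4j.rand(10000, 64);
            INDArray sims = cosineSimilarities(query, candidates);
            // The cosine distance used in the question would then be 1 - similarity.
            System.out.println(sims.getDouble(0));
        }
    }

Even after batching, the JDBC work inside the loop may dominate for vectors this small, so it is worth profiling again before concluding the GPU is at fault. The full code from the question follows for reference.
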
    import com.rabbitmq.client.Channel;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;
    
    import java.io.IOException;
    import java.sql.*;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ConcurrentHashMap;
    
    public class GraphUpdater implements Runnable {
        private Pair pubPair;
        private ConcurrentHashMap<Integer, INDArray> pubsList;
        private Connection connectionMain;
        private Connection connectionSite;
        private Channel channel;
    
        GraphUpdater(Pair pubPair, ConcurrentHashMap<Integer, INDArray> pubsList, Channel channel) throws SQLException {
        this.pubPair = pubPair;
        this.channel = channel;
        this.pubsList = pubsList;
        connectionMain = DataBaseConnectionsPool.getConnection();
        connectionSite = DataBaseConnectionsPool.getConnectionSite();
    }
    
    @Override
    public void run(){
        try {
            channel.basicAck(pubPair.deliveryTag, false);
        } catch (IOException e) {
            System.out.println("Error, pub="+pubPair.pub);
            e.printStackTrace();
        }
        PreparedStatement st;
        PreparedStatement stNew;
        try {
            st = connectionMain.prepareStatement("update vec_graph set closed_pubs=closed_pubs || ? where pub=?");
            stNew = connectionMain.prepareStatement("insert into vec_graph values (?, ?)");
    
            Statement psNew = connectionMain.createStatement();
            ResultSet rs = psNew.executeQuery("select * from new_public_vectors where pub="+pubPair.pub);
            float[] _floatArr = new float[64];
            while (rs.next()){
                Array arr = rs.getArray("vector");
                Object[] obj = (Object[]) arr.getArray();
                for (int vIndex=0; vIndex < 64; vIndex++){
                    _floatArr[vIndex] = (float)(double)obj[vIndex];
                }
                pubsList.put(rs.getInt(1), Nd4j.create(_floatArr));
            }
    
            //pub from task X all pubs from db
            int pub = pubPair.pub;
            List<Integer> closed = new ArrayList<>();
            double mean = 0.96D;
            INDArray currentVector = pubsList.get(pub);
            //!%!%!%!%slowly part of code
            for (int pubId : pubsList.keySet()) {
                INDArray publicVector = pubsList.get(pubId);
                if (currentVector == null || pub == pubId || publicVector == null){
                    continue;
                }
                //!%!%!%!%mega slowly part of code, ~99% of CPU time in VisualVM
                double dist = -Transforms.cosineDistance(currentVector, publicVector) + 1; // Transfer from cosine sim to cosine dist
                if ((dist - mean) < 0.01 && (dist - mean) > 0){
                    mean = (mean+dist)/2;
                }else if (dist > mean){
                    mean = dist;
                    closed.clear();
                    st.clearBatch();
                }else{
                    continue;
                }
                Array a = connectionMain.createArrayOf("int", new Object[]{pub});
                st.setArray(1, a);
                st.setInt(2, pubId);
                st.addBatch();
                closed.add(pubId);
            }
            Object[] obj_vector = new Object[closed.size()];
            for (int i = 0; i < closed.size(); i++){
                obj_vector[i] = closed.get(i);
            }
            Array closedArray = connectionMain.createArrayOf("int", obj_vector);
            stNew.setInt(1, pub);
            stNew.setArray(2, closedArray);
            stNew.addBatch();
    
            if (pubPair.byUser != 0){
                showToUser(closed, pub, pubPair.byUser);
            }
            try {
                st.executeBatch();
                stNew.executeBatch();
            }catch (BatchUpdateException e){
                e.printStackTrace();
                e.getNextException().printStackTrace();
            }
        } catch (BatchUpdateException e){
            e.printStackTrace();
            e.getNextException().printStackTrace();
        } catch (SQLException e) {
            e.printStackTrace();
        }finally {
            try {
                connectionMain.close();
                connectionSite.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }