我正在尝试使用 Thrust 库和 cuBLAS 库实现一个在 GPU 上运行的神经网络，但很难让它比当前的多线程、矢量化 CPU 实现运行得更快。该网络有一个由 logistic（sigmoid）单元组成的隐藏层和一个由线性单元组成的输出层，代码如下：


// Functor to add bias before computing logistic
template <typename T>
struct bias_logistic_f {
        __host__ __device__
        T operator()(const T& x, const T& y) const {
                return 1/(1+exp(-(x+y)));
bias_logistic_f bias_logistic();

// Device-side (thrust::device_vector) storage for the network's units,
// stored row-major:
//   batch    - one batch of inputs,              batch_rows x ndim
//   hid      - hidden activations for the batch, batch_rows x nhid
//   gpu_code - output codes for every batch,     ndata x ncode
thrust::device_vector<FLT> batch(batch_rows * ndim);
thrust::device_vector<FLT> hid(batch_rows * nhid);
thrust::device_vector<FLT> gpu_code(ndata * ncode);

// ...Load data and network weights...

// Multiply input (batch) by weights (vis2hid)
// Our matrices are stored row-major, but BLAS wants column-major,
// so pretend they're transposed and compute hid' = vis2hid' * batch'
cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, nhid, batch_rows, ndim,
            &alpha, thrust::raw_pointer_cast(&vis2hid[0]), nhid,
                    thrust::raw_pointer_cast(&batch[0]), ndim,
             &beta, thrust::raw_pointer_cast(&hid[0]), nhid);

// Add hidbiases to hid and apply the logistic function, element-wise and in
// place (thrust::transform allows the output range to coincide with the input).
// The listing's call was truncated before its binary-op argument. A temporary
// bias_logistic_f<FLT> is passed so the call does not depend on a separately
// declared functor object.
// NOTE(review): this reads hid.size() (= batch_rows*nhid) elements starting at
// hidbiases.begin(). That is only in bounds if hidbiases holds at least that
// many values (e.g. the bias vector replicated once per row). If it holds just
// nhid biases, index it modulo nhid (hid is row-major, so element i needs bias
// i % nhid), e.g. with thrust::make_permutation_iterator.
thrust::transform(hid.begin(), hid.end(), hidbiases.begin(), hid.begin(),
                  bias_logistic_f<FLT>());

// Multiply hid by weights (hid2code)
cublasDgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, ncode, batch_rows, nhid,
            &alpha, thrust::raw_pointer_cast(&hid2code[0]), ncode,
                    thrust::raw_pointer_cast(&hid[0]), nhid,
             &beta, thrust::raw_pointer_cast(&gpu_code[b*batch_rows*ncode]), ncode);

// Add codebiases to batch b's slice of gpu_code, element-wise and in place.
// The listing's call was truncated before its binary-op argument;
// thrust::plus (from <thrust/functional.h>) supplies the addition.
// NOTE(review): this reads batch_rows*ncode elements starting at
// codebiases.begin(). Confirm codebiases is replicated once per row, or index it
// modulo ncode; an ncode-length bias vector would be read out of bounds.
thrust::transform(gpu_code.begin() + b*batch_rows*ncode,
                  gpu_code.begin() + (b+1)*batch_rows*ncode,
                  codebiases.begin(),
                  gpu_code.begin() + b*batch_rows*ncode,
                  thrust::plus<FLT>());

// Walk the data set one batch at a time: densify batch b into `batch` on the
// device, then run it through the network (the forward pass listed above).
for(int b=0; b<nbatch; ++b) {
    // Zero the dense input buffer before scattering batch b's non-zeros into it.
    thrust::fill(batch.begin(), batch.end(), 0.0f);
    // batch_val holds the non-zero values and batch_idx their positions within
    // the batch; the range batch_ptr[b]..batch_ptr[b+1] selects batch b's
    // entries in both arrays.
    // This is like CSR format, except that each compressed unit is a whole
    // submatrix of 1,000 rows rather than a single row.
    // NOTE(review): if batch_ptr is a thrust::device_vector, each batch_ptr[...]
    // read below is a separate device-to-host transfer. A host-side copy of
    // batch_ptr would avoid that per-batch latency.
    thrust::scatter(batch_val.begin() + batch_ptr[b],
                    batch_val.begin() + batch_ptr[b+1],
                    batch_idx.begin() + batch_ptr[b],
    // NOTE(review): this scatter call is truncated in the listing. It still needs
    // its output iterator (presumably batch.begin()) and the closing ");".
    // The loop's closing brace is also not visible here.

    // ...Input batch to network (shown above)...

