Python Softmax的CUDA实现

Python Softmax的CUDA实现,python,cuda,numba,Python,Cuda,Numba,我希望使用CUDA提高我的softmax层的速度。由于缺少python和CUDA的示例,我希望在这里得到一些建议。我已经建立了一个幼稚的实现,并在寻找关于这一点的建议 @cuda.jit def softmax(X, w, b): m = X.shape[0] probs = np.zeros((m, 120)) startX=cuda.grid(2) gridX=cuda.gridDim.x * cuda.blockDim.x; for i in r

我希望使用CUDA提高我的softmax层的速度。由于缺少python和CUDA的示例,我希望在这里得到一些建议。我已经建立了一个幼稚的实现,并在寻找关于这一点的建议

@cuda.jit
def softmax(X, w, b):

    m = X.shape[0]

    probs = np.zeros((m, 120))
    startX=cuda.grid(2)
    gridX=cuda.gridDim.x * cuda.blockDim.x;
    for i in range(startX, m):
        X_slice = X[i,:,:,:]
        z = np.dot(X_slice,w).reshape(1, w.shape[-1])
        z_exp = np.exp(z) 
        z_probs = z_exp/np.sum(z_exp) 
        probs[i,:] =z_probs

    A_prev = (X, w, b)
    return probs, A_prev
我已经建立了一个幼稚的实现,并在寻找关于这个问题的建议

好的,让我们根据文档来看看您的代码:

@cuda.jit
def softmax(X, w, b):

    m = X.shape[0]                                     # OK

    probs = np.zeros((m, 120))                         # Illegal, array creation
    startX=cuda.grid(2)                                # OK
    gridX=cuda.gridDim.x * cuda.blockDim.x;            # OK but unused
    for i in range(startX, m):                         # Illegal, startX is not a scalar
        X_slice = X[i,:,:,:]                           # Illegal, array creation
        z = np.dot(X_slice,w).reshape(1, w.shape[-1])  # Illegal, array method
        z_exp = np.exp(z)                              # Illegal, array method
        z_probs = z_exp/np.sum(z_exp)                  # Illegal, array method
        probs[i,:] =z_probs                            # OK (except probs can't be created in kernel)

    A_prev = (X, w, b)                                 # OK tuples are supported
    return probs, A_prev                               # illegal, kernels can't return anything
因此,基本上,您的整个代码都是非法的,永远无法工作。如果您需要一些建议,请参阅我链接到的文档:

为了获得最佳性能,用户应该编写这样的代码:每个线程一次只处理一个元素


我的。CUDA Python编程不仅仅是在您已有的一些代码前面添加
@CUDA.jit
,并期望它能快速运行。您需要想象完整算法的元素代码,然后编写代码来执行该操作。对于上面的代码,这可能需要不止一个内核。

我的建议是尝试运行代码,当代码不可避免地无法运行时,返回一个具体的问题,您需要回答这个问题。在做这件事之前,我会告诉你们关于Numba文档的内容