2012-07-04 4 views
1

Я пытаюсь использовать функцию cublasSgemmBatched() в jcuda для матричного умножения, и я не уверен, как правильно обрабатывать передачу указателя и векторы пакетных матриц. Я буду очень благодарен, если кто-то знает, как изменить мой код, чтобы правильно справиться с этой проблемой. В этом примере массив C остается неизменным после cublasGetVector.cublasSgemmBatched использование с jcuda

public static void SsmmBatchJCublas(int m, int n, int k, float A[], float B[]){ 

    // Create a CUBLAS handle 
    cublasHandle handle = new cublasHandle(); 
    cublasCreate(handle); 

    // Allocate memory on the device 
    Pointer d_A = new Pointer(); 
    Pointer d_B = new Pointer(); 
    Pointer d_C = new Pointer(); 


    cudaMalloc(d_A, m*k * Sizeof.FLOAT); 
    cudaMalloc(d_B, n*k * Sizeof.FLOAT); 
    cudaMalloc(d_C, m*n * Sizeof.FLOAT); 

    float[] C = new float[m*n]; 
    // Copy the memory from the host to the device 
    cublasSetVector(m*k, Sizeof.FLOAT, Pointer.to(A), 1, d_A, 1); 
    cublasSetVector(n*k, Sizeof.FLOAT, Pointer.to(B), 1, d_B, 1); 
    cublasSetVector(m*n, Sizeof.FLOAT, Pointer.to(C), 1, d_C, 1); 

    Pointer[] Aarray = new Pointer[]{d_A}; 
    Pointer AarrayPtr = Pointer.to(Aarray); 
    Pointer[] Barray = new Pointer[]{d_B}; 
    Pointer BarrayPtr = Pointer.to(Barray); 
    Pointer[] Carray = new Pointer[]{d_C}; 
    Pointer CarrayPtr = Pointer.to(Carray); 

    // Execute sgemm 
    Pointer pAlpha = Pointer.to(new float[]{1}); 
    Pointer pBeta = Pointer.to(new float[]{0}); 


    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, pAlpha, AarrayPtr, Aarray.length, BarrayPtr, Barray.length, pBeta, CarrayPtr, Carray.length, Aarray.length); 
    // Copy the result from the device to the host 
    cublasGetVector(m*n, Sizeof.FLOAT, d_C, 1, Pointer.to(C), 1); 

    // Clean up 
    cudaFree(d_A); 
    cudaFree(d_B); 
    cudaFree(d_C); 
    cublasDestroy(handle); 
} 

ответ

1

Я спросил на официальном форуме jcuda и быстро получил ответ here.

+0

Пожалуйста, отредактируйте этот ответ, чтобы включить решение. – ThiefMaster