author    | Hieu Hoang <hieuhoang@gmail.com> | 2018-01-19 19:42:05 +0300
committer | Hieu Hoang <hieuhoang@gmail.com> | 2018-01-19 19:42:05 +0300
commit    | 6423317039a4fd6550b815d8afa8b8651205daf3 (patch)
tree      | 1df77f7165e55c83ce32c2b74e210ec3e2af0360
parent    | 3e9e1c036a91b5168626704dc0342f225aa156f3 (diff)
check the last CUDA error after running every kernel. The program may not have been compiled for the particular GPU, or shared memory may be set incorrectly
-rw-r--r-- | src/amun/gpu/dl4mt/gru.h               |  1
-rw-r--r-- | src/amun/gpu/mblas/matrix.h            |  2
-rw-r--r-- | src/amun/gpu/mblas/matrix_functions.cu | 15
-rw-r--r-- | src/amun/gpu/mblas/matrix_functions.h  |  6
-rw-r--r-- | src/amun/gpu/mblas/nth_element.cu      |  2
5 files changed, 26 insertions, 0 deletions
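For context, the pattern added throughout this commit looks roughly like the sketch below: a CUDA kernel launch returns no status by itself, so a cudaGetLastError() check immediately after the launch is what catches launch-time failures such as a binary built without code for the current GPU or an oversized shared-memory request. HANDLE_ERROR here is a stand-in modeled on amun's macro of the same name; its exact definition is not shown in this diff.

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Stand-in for amun's HANDLE_ERROR macro: abort with a message on any
// non-success CUDA status.
#define HANDLE_ERROR(call)                                        \
  do {                                                            \
    cudaError_t err = (call);                                     \
    if (err != cudaSuccess) {                                     \
      fprintf(stderr, "CUDA error: %s at %s:%d\n",                \
              cudaGetErrorString(err), __FILE__, __LINE__);       \
      exit(1);                                                    \
    }                                                             \
  } while (0)

__global__ void gDummy(float *out) { out[threadIdx.x] = 1.0f; }

int main() {
  float *d;
  HANDLE_ERROR(cudaMalloc(&d, 32 * sizeof(float)));

  // The launch itself returns nothing; an invalid configuration (wrong
  // arch, too much shared memory, bad grid/block dims) surfaces only
  // through cudaGetLastError() queried right after the launch.
  gDummy<<<1, 32>>>(d);
  HANDLE_ERROR(cudaGetLastError());

  HANDLE_ERROR(cudaFree(d));
  return 0;
}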
diff --git a/src/amun/gpu/dl4mt/gru.h b/src/amun/gpu/dl4mt/gru.h
index 04217962..fd9e80dc 100644
--- a/src/amun/gpu/dl4mt/gru.h
+++ b/src/amun/gpu/dl4mt/gru.h
@@ -227,6 +227,7 @@ class FastGRU: public Cell {
     gElementwiseOps<<<blocks, threads, 0, mblas::CudaStreamHandler::GetStream()>>>
       (nextWrap, stateWrap, ruhWrap, tempWrap, bWrap, bx1Wrap, bx2Wrap);
+    HANDLE_ERROR(cudaGetLastError());
 
     //PAUSE_TIMER("ElementwiseOps");
diff --git a/src/amun/gpu/mblas/matrix.h b/src/amun/gpu/mblas/matrix.h
index 4cf4b710..a79f6c47 100644
--- a/src/amun/gpu/mblas/matrix.h
+++ b/src/amun/gpu/mblas/matrix.h
@@ -39,6 +39,8 @@ T Sum(const T *data, unsigned count)
   HANDLE_ERROR( cudaStreamSynchronize(stream));
 
   gSum<<<1, 1, 0, stream>>>(data, count, *d_ret);
+  HANDLE_ERROR(cudaGetLastError());
+
   HANDLE_ERROR( cudaMemcpyAsync(&ret, d_ret, sizeof(T), cudaMemcpyDeviceToHost, stream) );
   HANDLE_ERROR( cudaStreamSynchronize(stream));
diff --git a/src/amun/gpu/mblas/matrix_functions.cu b/src/amun/gpu/mblas/matrix_functions.cu
index 75ab0dcd..59ba2af4 100644
--- a/src/amun/gpu/mblas/matrix_functions.cu
+++ b/src/amun/gpu/mblas/matrix_functions.cu
@@ -77,6 +77,7 @@ void Mean(Matrix& Out,
   gMean<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>>
     (outWrap, inWrap, sentenceLengthsWrap);
+  HANDLE_ERROR(cudaGetLastError());
 
 }
@@ -124,6 +125,7 @@ void WeightedMean(Matrix& Out,const Matrix& Weights, const Matrix& In, const mbl
   gWeightedMean<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>>
     (outWrap, weightsWrap, inWrap, mappingWrap);
+  HANDLE_ERROR(cudaGetLastError());
 
   /*
   cerr << "nBlocks=" << nBlocks << endl;
@@ -207,6 +209,7 @@ void PasteRows(Matrix& Out, const Matrix& In, const unsigned rowNo, unsigned col
   gPasteRows<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>>
     (outWrap, inWrap, rowNo, colNo);
+  HANDLE_ERROR(cudaGetLastError());
 
 }
@@ -286,6 +289,7 @@ Matrix& CopyRows(Matrix& Out,
   gCopyRows<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>>
     (outWrap, inWrap, indicesWrap);
+  HANDLE_ERROR(cudaGetLastError());
 
   return Out;
 }
@@ -344,6 +348,8 @@ Matrix& Slice(Matrix& Out,
   gSlice<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>>
     (outWrap, inWrap, n, dim);
+  HANDLE_ERROR(cudaGetLastError());
+
   return Out;
 }
@@ -528,6 +534,7 @@ Matrix& Softmax(Matrix& Out,
   gSoftMax<<<blocks, threads, shared, CudaStreamHandler::GetStream()>>>
     (outWrap, batchIdsWrap, sentenceLengthsWrap, threads);
+  HANDLE_ERROR(cudaGetLastError());
 
   return Out;
 }
@@ -622,6 +629,7 @@ Matrix& LogSoftmax(Matrix& Out)
   gLogSoftMax<<<blocks, threads, shared, CudaStreamHandler::GetStream()>>>
     (Out, threads);
+  HANDLE_ERROR(cudaGetLastError());
 
   return Out;
 }
@@ -645,6 +653,7 @@ void SetColumn(Matrix& In, int noColumn, float value) {
   gSetColumn<<<nBlocks, nThreads, 0, mblas::CudaStreamHandler::GetStream()>>>
     (inWrap, noColumn, value);
+  HANDLE_ERROR(cudaGetLastError());
 }
 
 __global__ void gFill(MatrixWrapper<float> in, float val) {
@@ -665,6 +674,7 @@ void Fill(Matrix& In, float value) {
     gFill<<<nBlocks, nThreads, 0, CudaStreamHandler::GetStream()>>>
       (inWrap, value);
+    HANDLE_ERROR(cudaGetLastError());
   }
   else {
     HANDLE_ERROR(cudaMemsetAsync(In.data(), 0, size * sizeof(float), CudaStreamHandler::GetStream()));
@@ -706,6 +716,7 @@ void MapMatrix(Matrix& state,
   gMapMatrix<<<numBlocks, numThreads, 0, CudaStreamHandler::GetStream()>>>
     (stateWrap, sentenceLengthsWrap, i);
+  HANDLE_ERROR(cudaGetLastError());
 
   /*
   cerr << "nBlocks=" << numBlocks << endl;
@@ -826,6 +837,7 @@ void Normalization(Matrix &out,
   gLNormalization<<<numBlocks, numThreads, shared, CudaStreamHandler::GetStream()>>>
     (outWrap, inWrap, alphaWrap, *betaWrap, eps);
+  HANDLE_ERROR(cudaGetLastError());
 
   /*
   //std::cerr << "nBlocks=" << numBlocks << std::endl;
@@ -1361,6 +1373,7 @@ void LogSoftmaxAndNBest(mblas::Vector<NthOutBatch> &nBest,
      beamSizeSum,
      beamSizesWrap
     );
+  HANDLE_ERROR(cudaGetLastError());
   //PAUSE_TIMER("gBeamSizeInit");
 
   /*
@@ -1385,6 +1398,7 @@ void LogSoftmaxAndNBest(mblas::Vector<NthOutBatch> &nBest,
      forbidUNK,
      hypo2BeamSizeWrap,
      hypo2CandidateWrap);
+  HANDLE_ERROR(cudaGetLastError());
   //PAUSE_TIMER("gLogSoftMax");
 
   //HANDLE_ERROR( cudaStreamSynchronize(mblas::CudaStreamHandler::GetStream()));
@@ -1402,6 +1416,7 @@ void LogSoftmaxAndNBest(mblas::Vector<NthOutBatch> &nBest,
      hypo2BeamSizeWrap,
      batch2HypoWrap,
      hypo2CandidateWrap);
+  HANDLE_ERROR(cudaGetLastError());
   //PAUSE_TIMER("gNBestPerBatch");
 
   //HANDLE_ERROR( cudaStreamSynchronize(mblas::CudaStreamHandler::GetStream()));
diff --git a/src/amun/gpu/mblas/matrix_functions.h b/src/amun/gpu/mblas/matrix_functions.h
index 1d30c15f..3c5dc04b 100644
--- a/src/amun/gpu/mblas/matrix_functions.h
+++ b/src/amun/gpu/mblas/matrix_functions.h
@@ -194,6 +194,7 @@ Matrix& Broadcast(Functor functor,
   gBroadcast<<<blocks, threads, 0, CudaStreamHandler::GetStream()>>>
     (functor, outWrap, in1Wrap, in2Wrap, batchMappingWrap);
+  HANDLE_ERROR(cudaGetLastError());
 
   /*
   std::cerr << "size=" << size << std::endl;
   std::cerr << "nBlocks=" << blocks << std::endl;
@@ -252,6 +253,7 @@ Matrix& BroadcastVecColumn(Functor functor, Matrix& Out, const mblas::Vector<flo
   gBroadcastVecColumn<<<blocks, threads, rows * sizeof(float), CudaStreamHandler::GetStream()>>>
     (functor, outWrap, inWrap);
+  HANDLE_ERROR(cudaGetLastError());
 
   return Out;
 }
@@ -296,6 +298,7 @@ Matrix& BroadcastVec(Functor functor, Matrix& Out, const Matrix& In)
   gBroadcastVec<<<blocks, threads, 0, stream>>>
     (functor, outWrap, inWrap);
+  HANDLE_ERROR(cudaGetLastError());
 
   return Out;
 }
@@ -323,6 +326,7 @@ Matrix& Element(Functor functor,
   gElement<<<blocks, threads, 0, stream>>>
     (functor, outWrap);
+  HANDLE_ERROR(cudaGetLastError());
 
   return Out;
 }
@@ -354,6 +358,7 @@ Matrix& Element(Functor functor,
   gElement<<<blocks, threads, 0, stream>>>
     (functor, outWrap, inWrap);
+  HANDLE_ERROR(cudaGetLastError());
 
   return Out;
 }
@@ -397,6 +402,7 @@ Matrix& Element(Functor functor,
   gElement<<<blocks, threads, 0, stream>>>
     (functor, outWrap, in1Wrap, in2Wrap);
+  HANDLE_ERROR(cudaGetLastError());
 
   //HANDLE_ERROR( cudaPeekAtLastError() );
   //HANDLE_ERROR( cudaDeviceSynchronize() );
diff --git a/src/amun/gpu/mblas/nth_element.cu b/src/amun/gpu/mblas/nth_element.cu
index 9bdfe5aa..6df073c6 100644
--- a/src/amun/gpu/mblas/nth_element.cu
+++ b/src/amun/gpu/mblas/nth_element.cu
@@ -107,6 +107,7 @@ void NthElement::getNBestList(mblas::Matrix &probs,
   gMaxElement<<<numBlocks, BLOCK_SIZE, BLOCK_SIZE * sizeof(float), mblas::CudaStreamHandler::GetStream()>>>
     (outWrap, probsWrap, batchPositionWrap, numBatches);
+  HANDLE_ERROR(cudaGetLastError());
 
   gMaxElementUpdate<<<numBatches, BLOCK_SIZE, BLOCK_SIZE * sizeof(float), mblas::CudaStreamHandler::GetStream()>>>
     (outWrap,
@@ -115,6 +116,7 @@ void NthElement::getNBestList(mblas::Matrix &probs,
      batchPositionWrap,
      cumBeamSizesWrap,
      numBlocks);
+  HANDLE_ERROR(cudaGetLastError());
 
   /*
   cerr << "numBlocks=" << numBlocks << endl;
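As a usage note (a hypothetical repro, not part of the commit): both failure modes named in the commit message surface through exactly this check. A binary compiled without code for the running GPU fails the launch with "invalid device function", and an oversized dynamic shared-memory request fails it with "invalid argument". The sketch below triggers the latter deliberately.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void gNoop() {}

int main() {
  // Request ~1 GB of dynamic shared memory -- far beyond any device's
  // per-block limit -- so the launch configuration is rejected.
  gNoop<<<1, 1, 1u << 30>>>();

  // Without this check the failed launch would go unnoticed.
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  return 0;
}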