Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2018-01-16 14:29:24 +0300
committerHieu Hoang <hieuhoang@gmail.com>2018-01-16 14:29:24 +0300
commit8443ddb445f89e1bc123c58a81f50d1272da986b (patch)
tree2dd4495204bba967f6e6e6cc858b209123b89151
parent79167a699fa6c0b30c506d0d892f8b3f1c4fa8ec (diff)
use Row() in AddElement()
-rw-r--r--src/amun/gpu/mblas/matrix_functions.cu9
1 files changed, 5 insertions, 4 deletions
diff --git a/src/amun/gpu/mblas/matrix_functions.cu b/src/amun/gpu/mblas/matrix_functions.cu
index 26a0c9f9..1625b67d 100644
--- a/src/amun/gpu/mblas/matrix_functions.cu
+++ b/src/amun/gpu/mblas/matrix_functions.cu
@@ -931,7 +931,7 @@ float GetMaxScore(const MatrixWrapper<NthOutBatch> &nBestMatrix)
__device__
void AddElement(float &minScore,
unsigned &i,
- NthOutBatch *arr,
+ VectorWrapper<NthOutBatch> &vec,
bool forbidUNK,
unsigned vocabInd,
const NthOutBatch &ele)
@@ -939,11 +939,11 @@ void AddElement(float &minScore,
const float score = ele.score;
if (forbidUNK && vocabInd == UNK_ID) {
- arr[i].score = LOWEST_FLOAT;
+ vec[i].score = LOWEST_FLOAT;
minScore = LOWEST_FLOAT;
}
else {
- arr[i] = ele;
+ vec[i] = ele;
if (score < minScore) {
minScore = score;
@@ -1021,6 +1021,7 @@ void NBestAndMax(VectorWrapper<NthOutBatch> &nBestCandidatesWrap,
void *ptrOffset = _sharePtr + sizeof(float) * blockDim.x;
MatrixWrapper<NthOutBatch> nBestMatrix((NthOutBatch*)ptrOffset, blockDim.x, maxBeamSize, 1, 1);
NthOutBatch *arr = &nBestMatrix(threadIdx.x);
+ VectorWrapper<NthOutBatch> row = nBestMatrix.Row(threadIdx.x);
unsigned vocabSize = in.dim(1);
@@ -1038,7 +1039,7 @@ void NBestAndMax(VectorWrapper<NthOutBatch> &nBestCandidatesWrap,
unsigned arrInd = hypoInd * vocabSize + vocabInd;
NthOutBatch ele(arrInd, score, hypoInd, vocabInd);
- AddElement(minScore, i, arr, forbidUNK, vocabInd, ele);
+ AddElement(minScore, i, row, forbidUNK, vocabInd, ele);
vocabInd += blockDim.x;
}