Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2018-01-16 14:47:35 +0300
committerHieu Hoang <hieuhoang@gmail.com>2018-01-16 14:47:35 +0300
commitd65c22ccaa046a1b445471dfea4819f0acb1129d (patch)
tree02fc7a002d9cdb1730aed1b86bb63ffecf69769c
parent8443ddb445f89e1bc123c58a81f50d1272da986b (diff)
add Offset(). Start doing MergingElement()
-rw-r--r--src/amun/gpu/mblas/matrix_functions.cu17
-rw-r--r--src/amun/gpu/mblas/vector_wrapper.h17
2 files changed, 27 insertions, 7 deletions
diff --git a/src/amun/gpu/mblas/matrix_functions.cu b/src/amun/gpu/mblas/matrix_functions.cu
index 1625b67d..d1c50933 100644
--- a/src/amun/gpu/mblas/matrix_functions.cu
+++ b/src/amun/gpu/mblas/matrix_functions.cu
@@ -956,14 +956,14 @@ void AddElement(float &minScore,
__device__
void MergeElement(float &minScore,
- NthOutBatch *arr,
+ VectorWrapper<NthOutBatch> &vec,
unsigned arrSize,
const NthOutBatch &ele)
{
float newMinScore = HIGHEST_FLOAT;
bool found = false;
for (unsigned i = 0; i < arrSize; ++i) {
- NthOutBatch &currEle = arr[i];
+ NthOutBatch &currEle = vec[i];
if (!found && minScore == currEle.score) {
currEle = ele;
found = true;
@@ -981,6 +981,7 @@ void MergeElement(float &minScore,
__device__
void MergeElement(float &minScore,
NthOutBatch *arr,
+ VectorWrapper<NthOutBatch> &vec,
unsigned arrSize,
const NthOutBatch &ele,
bool forbidUNK,
@@ -991,7 +992,7 @@ void MergeElement(float &minScore,
}
else if (ele.score > minScore) {
// replace element with min score
- MergeElement(minScore, arr, arrSize, ele);
+ MergeElement(minScore, vec, arrSize, ele);
/*
printf("arrInd=%d ind=%d vocabId=%d \n",
@@ -1050,7 +1051,7 @@ void NBestAndMax(VectorWrapper<NthOutBatch> &nBestCandidatesWrap,
unsigned arrInd = hypoInd * vocabSize + vocabInd;
NthOutBatch ele(arrInd, score, hypoInd, vocabInd);
- MergeElement(minScore, arr, beamSize, ele, forbidUNK, vocabInd);
+ MergeElement(minScore, arr, row, beamSize, ele, forbidUNK, vocabInd);
vocabInd += blockDim.x;
} // while (vocabInd < vocabSize) {
@@ -1061,12 +1062,11 @@ void NBestAndMax(VectorWrapper<NthOutBatch> &nBestCandidatesWrap,
__syncthreads();
int skip = (len + 1) >> 1;
if (threadIdx.x < (len >> 1)) {
- NthOutBatch *dest = &nBestMatrix(threadIdx.x);
for (unsigned i = 0; i < beamSize; ++i) {
const NthOutBatch &ele = nBestMatrix(threadIdx.x + skip, i);
if (ele.score > minScore) {
- MergeElement(minScore, dest, beamSize, ele);
+ MergeElement(minScore, row, beamSize, ele);
}
}
}
@@ -1253,6 +1253,9 @@ __global__ void gNBestPerBatch(VectorWrapper<NthOutBatch> nBestWrap,
// candidates from other previous hypos
if (!isFirst) {
+ assert(nextHypoInd < nBestWrap.size());
+ VectorWrapper<NthOutBatch> offset = nBestWrap.Offset(nextHypoInd);
+
for (unsigned hypoOffset = 1; hypoOffset < beamSize; ++hypoOffset) {
//printf("hypoInd=%d \n", (hypoInd + hypoOffset));
@@ -1272,7 +1275,7 @@ __global__ void gNBestPerBatch(VectorWrapper<NthOutBatch> nBestWrap,
NthOutBatch *arr = &nBestWrap[nextHypoInd];
if (candidate.score > minScore) {
- MergeElement(minScore, arr, beamSize, candidate);
+ MergeElement(minScore, offset, beamSize, candidate);
}
}
}
diff --git a/src/amun/gpu/mblas/vector_wrapper.h b/src/amun/gpu/mblas/vector_wrapper.h
index a544f83e..98ccfb85 100644
--- a/src/amun/gpu/mblas/vector_wrapper.h
+++ b/src/amun/gpu/mblas/vector_wrapper.h
@@ -1,4 +1,5 @@
#pragma once
+#include <sstream>
#include "matrix.h"
#include "gpu/mblas/vector.h"
@@ -67,6 +68,22 @@ public:
return data()[i];
}
+ __device__
+ VectorWrapper<T> Offset(unsigned offset)
+ {
+ T &ele = (*this)[offset];
+ VectorWrapper<T> ret(&ele, size_ - offset);
+ return ret;
+ }
+
+ std::string Debug() const
+ {
+ std::stringstream strm;
+ strm << "size_=" << size_;
+
+ return strm.str();
+ }
+
protected:
unsigned size_;