Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2017-12-01 18:44:05 +0300
committerHieu Hoang <hieuhoang@gmail.com>2017-12-01 18:44:05 +0300
commit80b4e907accb626ccd162bc9cd6263f177fd3cc9 (patch)
tree8001e6f711f27290bb4bc94277e429a3e798b952
parent22079a4aed6483d158075ea65b424fa488790ed8 (diff)
work on placeholders in NthOutBatch
-rw-r--r--src/amun/half/mblas/nth_element_kernels.h20
1 files changed, 16 insertions, 4 deletions
diff --git a/src/amun/half/mblas/nth_element_kernels.h b/src/amun/half/mblas/nth_element_kernels.h
index 5f616508..e784f98f 100644
--- a/src/amun/half/mblas/nth_element_kernels.h
+++ b/src/amun/half/mblas/nth_element_kernels.h
@@ -2,6 +2,7 @@
#include "matrix_wrapper.h"
#include "vector_wrapper.h"
+#include "thrust_functions.h"
namespace amunmt {
namespace GPUHalf {
@@ -44,13 +45,24 @@ struct NthOutBatch
//uint hypoInd;
//uint vocabInd;
- __device__ __host__
+ __host__
NthOutBatch(const float& rhs)
{
// only to be used to init variable in matrix.h gSum
assert(rhs == 0.0f);
ind = rhs;
- //score = rhs; //HH
+ score = float2half_rn(rhs);
+ //hypoInd = rhs;
+ //vocabInd = rhs;
+ }
+
+ __device__
+ NthOutBatch(const half& rhs)
+ {
+ // only to be used to init variable in matrix.h gSum
+ //assert(rhs == 0.0f);
+ ind = rhs;
+ score = rhs;
//hypoInd = rhs;
//vocabInd = rhs;
}
@@ -76,11 +88,11 @@ struct NthOutBatch
return *this;
}
- __device__ __host__
+ __device__
NthOutBatch& operator+=(const NthOutBatch& rhs)
{
ind += rhs.ind;
- //score += rhs.score; //HH
+ score += rhs.score;
//hypoInd += rhs.hypoInd;
//vocabInd += rhs.vocabInd;
return *this;