Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2017-12-01 15:02:01 +0300
committerHieu Hoang <hieuhoang@gmail.com>2017-12-01 15:02:01 +0300
commitbd35ba821d8dc26e43c50110c4ebfecca9559a8f (patch)
treeb39f870b1263c57065a62e37c01cac00052d3852
parent53b2aa714a0fcfcde81a6e0e6b14cac00673f979 (diff)
begin CopyNthOutBatch()
-rw-r--r--src/amun/half/decoder/best_hyps.cu14
-rw-r--r--src/amun/half/decoder/best_hyps.h2
-rw-r--r--src/amun/half/mblas/matrix_functions.cu17
-rw-r--r--src/amun/half/mblas/matrix_functions.h3
4 files changed, 23 insertions, 13 deletions
diff --git a/src/amun/half/decoder/best_hyps.cu b/src/amun/half/decoder/best_hyps.cu
index fdcb6209..c59c169b 100644
--- a/src/amun/half/decoder/best_hyps.cu
+++ b/src/amun/half/decoder/best_hyps.cu
@@ -198,21 +198,11 @@ void BestHyps::getNBestList(const std::vector<uint>& beamSizes,
//cerr << endl;
}
-void BestHyps::GetPairs(mblas::Vector<NthOutBatch> &nBest,
+void BestHyps::GetPairs(const mblas::Vector<NthOutBatch> &nBest,
std::vector<uint>& outKeys,
std::vector<float>& outValues) const
{
- //cerr << "top=" << top2.size() << " nBest=" << nBest.size() << endl;
- outKeys.resize(nBest.size());
- outValues.resize(nBest.size());
-
- std::vector<NthOutBatch> hostVec(nBest.size());
- mblas::copy(nBest.data(), nBest.size(), hostVec.data(), cudaMemcpyDeviceToHost);
-
- for (size_t i = 0; i < nBest.size(); ++i) {
- outKeys[i] = hostVec[i].ind;
- outValues[i] = half2float(hostVec[i].score);
- }
+ CopyNthOutBatch(nBest, outKeys, outValues);
}
} // namespace
diff --git a/src/amun/half/decoder/best_hyps.h b/src/amun/half/decoder/best_hyps.h
index 275bd019..545b5f68 100644
--- a/src/amun/half/decoder/best_hyps.h
+++ b/src/amun/half/decoder/best_hyps.h
@@ -58,7 +58,7 @@ class BestHyps : public BestHypsBase
std::vector<uint>& outKeys,
const bool isFirst=false) const;
- void GetPairs(mblas::Vector<NthOutBatch> &nBest,
+ void GetPairs(const mblas::Vector<NthOutBatch> &nBest,
std::vector<uint>& outKeys,
std::vector<float>& outValues) const;
diff --git a/src/amun/half/mblas/matrix_functions.cu b/src/amun/half/mblas/matrix_functions.cu
index 54e06319..a8b89630 100644
--- a/src/amun/half/mblas/matrix_functions.cu
+++ b/src/amun/half/mblas/matrix_functions.cu
@@ -1451,6 +1451,23 @@ void TestMemCpy()
cerr << "Finished" << endl;
}
+void CopyNthOutBatch(const mblas::Vector<NthOutBatch> &nBest,
+ std::vector<uint>& outKeys,
+ std::vector<float>& outValues)
+{
+ //cerr << "top=" << top2.size() << " nBest=" << nBest.size() << endl;
+ outKeys.resize(nBest.size());
+ outValues.resize(nBest.size());
+
+ std::vector<NthOutBatch> hostVec(nBest.size());
+ mblas::copy(nBest.data(), nBest.size(), hostVec.data(), cudaMemcpyDeviceToHost);
+
+ for (size_t i = 0; i < nBest.size(); ++i) {
+ outKeys[i] = hostVec[i].ind;
+ outValues[i] = half2float(hostVec[i].score);
+ }
+}
+
} // namespace mblas
} // namespace GPU
} // namespace amunmt
diff --git a/src/amun/half/mblas/matrix_functions.h b/src/amun/half/mblas/matrix_functions.h
index e0aeda97..a7fb6cff 100644
--- a/src/amun/half/mblas/matrix_functions.h
+++ b/src/amun/half/mblas/matrix_functions.h
@@ -479,6 +479,9 @@ void TestMemCpy(size_t size, const T *data1)
void TestMemCpy();
+void CopyNthOutBatch(const mblas::Vector<NthOutBatch> &nBest,
+ std::vector<uint>& outKeys,
+ std::vector<float>& outValues);
} // namespace mblas
} // namespace GPU