Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHieu Hoang <hieuhoang@gmail.com>2017-12-15 00:46:36 +0300
committerHieu Hoang <hieuhoang@gmail.com>2017-12-15 00:46:36 +0300
commit71cfd04c863b6adeeb3877532c2dd4ad4dfaff61 (patch)
tree50b19d9306814c55cf4ca01209a6590ac179922a /src/amun/gpu
parent9f4addbf22d9396b89184db4ac231756b0f77510 (diff)
nbest for ensemble
Diffstat (limited to 'src/amun/gpu')
-rw-r--r--src/amun/gpu/mblas/nth_element.cu12
1 files changed, 9 insertions, 3 deletions
diff --git a/src/amun/gpu/mblas/nth_element.cu b/src/amun/gpu/mblas/nth_element.cu
index 177ccfb8..c4a426a1 100644
--- a/src/amun/gpu/mblas/nth_element.cu
+++ b/src/amun/gpu/mblas/nth_element.cu
@@ -154,15 +154,21 @@ void NthElement::GetPairs(uint number,
void NthElement::getValueByKey(std::vector<float>& out, const mblas::Matrix &d_in) const
{
// need a model with multiple scorers to test this method
- assert(false);
+ out.resize(d_breakdown.size());
mblas::VectorWrapper<float> breakdownWrap(d_breakdown);
const mblas::MatrixWrapper<float> inWrap(d_in);
//gGetValueByKey<<<1, lastN_, 0, stream_>>>
// (breakdownWrap, inWrap, h_res_idx, lastN_);
-
- HANDLE_ERROR( cudaMemcpyAsync(out.data(), d_breakdown.data(), h_res.size() * sizeof(float),
+ /*
+ cerr << "out="
+ << out.size() << " "
+ << d_breakdown.size() << " "
+ << h_res.size()
+ << endl;
+ */
+ HANDLE_ERROR( cudaMemcpyAsync(out.data(), d_breakdown.data(), d_breakdown.size() * sizeof(float),
cudaMemcpyDeviceToHost, mblas::CudaStreamHandler::GetStream()) );
HANDLE_ERROR( cudaStreamSynchronize(mblas::CudaStreamHandler::GetStream()));
}