Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/amun/gpu/mblas/nth_element.cu')
-rw-r--r--src/amun/gpu/mblas/nth_element.cu12
1 files changed, 9 insertions, 3 deletions
diff --git a/src/amun/gpu/mblas/nth_element.cu b/src/amun/gpu/mblas/nth_element.cu
index 177ccfb8..c4a426a1 100644
--- a/src/amun/gpu/mblas/nth_element.cu
+++ b/src/amun/gpu/mblas/nth_element.cu
@@ -154,15 +154,21 @@ void NthElement::GetPairs(uint number,
void NthElement::getValueByKey(std::vector<float>& out, const mblas::Matrix &d_in) const
{
// need a model with multiple scorers to test this method
- assert(false);
+ out.resize(d_breakdown.size());
mblas::VectorWrapper<float> breakdownWrap(d_breakdown);
const mblas::MatrixWrapper<float> inWrap(d_in);
//gGetValueByKey<<<1, lastN_, 0, stream_>>>
// (breakdownWrap, inWrap, h_res_idx, lastN_);
-
- HANDLE_ERROR( cudaMemcpyAsync(out.data(), d_breakdown.data(), h_res.size() * sizeof(float),
+ /*
+ cerr << "out="
+ << out.size() << " "
+ << d_breakdown.size() << " "
+ << h_res.size()
+ << endl;
+ */
+ HANDLE_ERROR( cudaMemcpyAsync(out.data(), d_breakdown.data(), d_breakdown.size() * sizeof(float),
cudaMemcpyDeviceToHost, mblas::CudaStreamHandler::GetStream()) );
HANDLE_ERROR( cudaStreamSynchronize(mblas::CudaStreamHandler::GetStream()));
}