Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xsrc/graph/node_operators_unary.h2
-rwxr-xr-x[-rw-r--r--]src/models/encoder_decoder.cpp5
-rwxr-xr-xsrc/translator/beam_search.h2
-rwxr-xr-xsrc/translator/nth_element.cpp21
4 files changed, 15 insertions, 15 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 190fa947..c5141092 100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -717,6 +717,8 @@ private:
public:
ReshapeNodeOp(Expr a, Shape shape) : UnaryNodeOp(a, shape, a->value_type()), reshapee_(a) {
+ ABORT_IF(a->shape().elements() != shape.elements(),
+ "Reshape must not change the number of elements (from {} to {})", a->shape().toString(), shape.toString());
Node::destroy_ = false;
}
diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp
index 4b072c73..33af5d01 100644..100755
--- a/src/models/encoder_decoder.cpp
+++ b/src/models/encoder_decoder.cpp
@@ -170,9 +170,8 @@ Ptr<DecoderState> EncoderDecoder::step(Ptr<ExpressionGraph> graph,
// create updated state that reflects reordering and dropping of hypotheses
state = hypIndices.empty() ? state : state->select(hypIndices, beamSize);
- // Fill stte with embeddings based on last prediction
- decoders_[0]->embeddingsFromPrediction(
- graph, state, embIndices, dimBatch, beamSize);
+ // Fill state with embeddings based on last prediction
+ decoders_[0]->embeddingsFromPrediction(graph, state, embIndices, dimBatch, beamSize);
auto nextState = decoders_[0]->step(graph, state);
return nextState;
diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h
index 60659a22..7e3cb385 100755
--- a/src/translator/beam_search.h
+++ b/src/translator/beam_search.h
@@ -267,7 +267,7 @@ public:
if(dimBatch > 1 && localBeamSize > 1)
expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [dimBatch, 1, localBeamSize, dimVocab]
else // (avoid copy if we can)
- expandedPathScores = reshape(expandedPathScores, {dimBatch, 1, (int)localBeamSize, expandedPathScores->shape()[-1]});
+ expandedPathScores = reshape(expandedPathScores, {expandedPathScores->shape()[-2], 1, expandedPathScores->shape()[-4], expandedPathScores->shape()[-1]});
// perform NN computation
if(t == 0)
diff --git a/src/translator/nth_element.cpp b/src/translator/nth_element.cpp
index f99b0be4..febf0739 100755
--- a/src/translator/nth_element.cpp
+++ b/src/translator/nth_element.cpp
@@ -43,7 +43,7 @@ private:
std::vector<int>::iterator middle = begin + beamSize;
std::vector<int>::iterator end = idxs.begin() + batchFirstElementIdxs[batchIdx + 1];
std::partial_sort(
- begin, middle, end, [=](int a, int b) { return scores[a] > scores[b]; });
+ begin, middle, end, [&](int a, int b) { return scores[a] > scores[b]; });
while(begin != middle) {
int idx = *begin++;
@@ -67,33 +67,32 @@ public:
const auto dimBatch = scores->shape()[-4];
ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??");
- const std::vector<size_t> beamSizes(dimBatch, N);
- std::vector<int> cumulativeBeamSizes(beamSizes.size() + 1, 0);
- std::vector<int> batchFirstElementIdxs(beamSizes.size() + 1, 0);
+ std::vector<int> cumulativeBeamSizes(dimBatch + 1, 0);
+ std::vector<int> batchFirstElementIdxs(dimBatch + 1, 0);
- for(int batchIdx = 0; batchIdx < beamSizes.size(); ++batchIdx) {
- cumulativeBeamSizes[batchIdx + 1] = cumulativeBeamSizes[batchIdx] + (int)beamSizes[batchIdx];
+ for(int batchIdx = 0; batchIdx < dimBatch; ++batchIdx) {
+ cumulativeBeamSizes[batchIdx + 1] = cumulativeBeamSizes[batchIdx] + (int)N;
ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != (batchIdx + 1) * N, "cumulativeBeamSizes wrong??");
batchFirstElementIdxs[batchIdx + 1]
+= (isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) * vocabSize;
ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??");
}
+ ABORT_IF(cumulativeBeamSizes.back() != dimBatch * N, "cumulativeBeamSizes.back() wrong??");
size_t maxSize = N * dimBatch;
h_res.resize(maxSize);
h_res_idx.resize(maxSize);
selectNBest(scores->data(), batchFirstElementIdxs, cumulativeBeamSizes);
- getPairs(cumulativeBeamSizes.back(), outKeys, outPathScores);
- ABORT_IF(cumulativeBeamSizes.back() != dimBatch * N, "cumulativeBeamSizes.back() wrong??");
+ getPairs(/*cumulativeBeamSizes.back(),*/ outKeys, outPathScores);
}
private:
- void getPairs(size_t number,
+ void getPairs(/*size_t number,*/
std::vector<unsigned>& outKeys,
std::vector<float>& outValues) {
- std::copy(h_res_idx.begin(), h_res_idx.begin() + number, std::back_inserter(outKeys));
- std::copy(h_res .begin(), h_res .begin() + number, std::back_inserter(outValues));
+ std::copy(h_res_idx.begin(), h_res_idx.end(), std::back_inserter(outKeys));
+ std::copy(h_res .begin(), h_res .end(), std::back_inserter(outValues));
//lastN_ = number;
}