diff options
-rwxr-xr-x | src/graph/node_operators_unary.h | 2 | ||||
-rwxr-xr-x[-rw-r--r--] | src/models/encoder_decoder.cpp | 5 | ||||
-rwxr-xr-x | src/translator/beam_search.h | 2 | ||||
-rwxr-xr-x | src/translator/nth_element.cpp | 21 |
4 files changed, 15 insertions, 15 deletions
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index 190fa947..c5141092 100755 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -717,6 +717,8 @@ private: public: ReshapeNodeOp(Expr a, Shape shape) : UnaryNodeOp(a, shape, a->value_type()), reshapee_(a) { + ABORT_IF(a->shape().elements() != shape.elements(), + "Reshape must not change the number of elements (from {} to {})", a->shape().toString(), shape.toString()); Node::destroy_ = false; } diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp index 4b072c73..33af5d01 100644..100755 --- a/src/models/encoder_decoder.cpp +++ b/src/models/encoder_decoder.cpp @@ -170,9 +170,8 @@ Ptr<DecoderState> EncoderDecoder::step(Ptr<ExpressionGraph> graph, // create updated state that reflects reordering and dropping of hypotheses state = hypIndices.empty() ? state : state->select(hypIndices, beamSize); - // Fill stte with embeddings based on last prediction - decoders_[0]->embeddingsFromPrediction( - graph, state, embIndices, dimBatch, beamSize); + // Fill state with embeddings based on last prediction + decoders_[0]->embeddingsFromPrediction(graph, state, embIndices, dimBatch, beamSize); auto nextState = decoders_[0]->step(graph, state); return nextState; diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h index 60659a22..7e3cb385 100755 --- a/src/translator/beam_search.h +++ b/src/translator/beam_search.h @@ -267,7 +267,7 @@ public: if(dimBatch > 1 && localBeamSize > 1) expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [dimBatch, 1, localBeamSize, dimVocab] else // (avoid copy if we can) - expandedPathScores = reshape(expandedPathScores, {dimBatch, 1, (int)localBeamSize, expandedPathScores->shape()[-1]}); + expandedPathScores = reshape(expandedPathScores, {expandedPathScores->shape()[-2], 1, expandedPathScores->shape()[-4], expandedPathScores->shape()[-1]}); // perform NN computation if(t == 0) diff --git a/src/translator/nth_element.cpp b/src/translator/nth_element.cpp index f99b0be4..febf0739 100755 --- a/src/translator/nth_element.cpp +++ b/src/translator/nth_element.cpp @@ -43,7 +43,7 @@ private: std::vector<int>::iterator middle = begin + beamSize; std::vector<int>::iterator end = idxs.begin() + batchFirstElementIdxs[batchIdx + 1]; std::partial_sort( - begin, middle, end, [=](int a, int b) { return scores[a] > scores[b]; }); + begin, middle, end, [&](int a, int b) { return scores[a] > scores[b]; }); while(begin != middle) { int idx = *begin++; @@ -67,33 +67,32 @@ public: const auto dimBatch = scores->shape()[-4]; ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??"); - const std::vector<size_t> beamSizes(dimBatch, N); - std::vector<int> cumulativeBeamSizes(beamSizes.size() + 1, 0); - std::vector<int> batchFirstElementIdxs(beamSizes.size() + 1, 0); + std::vector<int> cumulativeBeamSizes(dimBatch + 1, 0); + std::vector<int> batchFirstElementIdxs(dimBatch + 1, 0); - for(int batchIdx = 0; batchIdx < beamSizes.size(); ++batchIdx) { - cumulativeBeamSizes[batchIdx + 1] = cumulativeBeamSizes[batchIdx] + (int)beamSizes[batchIdx]; + for(int batchIdx = 0; batchIdx < dimBatch; ++batchIdx) { + cumulativeBeamSizes[batchIdx + 1] = cumulativeBeamSizes[batchIdx] + (int)N; ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != (batchIdx + 1) * N, "cumulativeBeamSizes wrong??"); batchFirstElementIdxs[batchIdx + 1] += (isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) * vocabSize; ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??"); } + ABORT_IF(cumulativeBeamSizes.back() != dimBatch * N, "cumulativeBeamSizes.back() wrong??"); size_t maxSize = N * dimBatch; h_res.resize(maxSize); h_res_idx.resize(maxSize); selectNBest(scores->data(), batchFirstElementIdxs, cumulativeBeamSizes); - getPairs(cumulativeBeamSizes.back(), outKeys, outPathScores); - ABORT_IF(cumulativeBeamSizes.back() != dimBatch * N, "cumulativeBeamSizes.back() wrong??"); + getPairs(/*cumulativeBeamSizes.back(),*/ outKeys, outPathScores); } private: - void getPairs(size_t number, + void getPairs(/*size_t number,*/ std::vector<unsigned>& outKeys, std::vector<float>& outValues) { - std::copy(h_res_idx.begin(), h_res_idx.begin() + number, std::back_inserter(outKeys)); - std::copy(h_res .begin(), h_res .begin() + number, std::back_inserter(outValues)); + std::copy(h_res_idx.begin(), h_res_idx.end(), std::back_inserter(outKeys)); + std::copy(h_res .begin(), h_res .end(), std::back_inserter(outValues)); //lastN_ = number; } |