diff options
author | Guillaume Klein <guillaume.klein@systrangroup.com> | 2018-06-08 18:41:24 +0300 |
---|---|---|
committer | Guillaume Klein <guillaume.klein@systrangroup.com> | 2018-06-08 18:41:24 +0300 |
commit | 84b4509327fa3d59acf9158ab8ddea29cc6e6aa2 (patch) | |
tree | 2f5d54d959159defad1faa2a82f4790438b35f6f | |
parent | 40da1ed2421770fde5485cf9c6dcfa570e749cf3 (diff) |
fix vmap (benchmark-3)
-rw-r--r-- | cli/translate.cc | 5 | ||||
-rw-r--r-- | src/decoder.cc | 18 |
2 files changed, 11 insertions, 12 deletions
```diff
diff --git a/cli/translate.cc b/cli/translate.cc
index 22bf71eb..a4d15a50 100644
--- a/cli/translate.cc
+++ b/cli/translate.cc
@@ -117,12 +117,13 @@ int main(int argc, char* argv[]) {
   size_t max_batch_size = argc > 1 ? std::stoi(argv[1]) : 1;
   size_t beam_size = argc > 2 ? std::stoi(argv[2]) : 1;
   size_t inter_threads = argc > 3 ? std::stoi(argv[3]) : 1;
-  std::string model_path = argc > 4 ? argv[4] : "/home/klein/dev/ctransformer/model.bin";
+  std::string model_path = argc > 4 ? argv[4] : "/home/klein/dev/ctransformer/ende_transformer.bin";
+  std::string vmap = argc > 5 ? argv[5] : "";
   std::string vocabulary_path = "/home/klein/data/wmt-ende/wmtende.vocab";
   opennmt::TransformerModel model(model_path, vocabulary_path);
   std::vector<opennmt::Translator> translator_pool;
-  translator_pool.emplace_back(model, 200, beam_size, 0.6, "");
+  translator_pool.emplace_back(model, 200, beam_size, 0.6, vmap);
   for (size_t i = 1; i < inter_threads; ++i) {
     translator_pool.emplace_back(translator_pool.front());
   }
diff --git a/src/decoder.cc b/src/decoder.cc
index 8acab390..13abc9e4 100644
--- a/src/decoder.cc
+++ b/src/decoder.cc
@@ -163,10 +163,12 @@ namespace opennmt {
       topk_ids.reshape({cur_batch_size, beam_size});
       gather_indices.resize({cur_batch_size, beam_size});
       for (size_t i = 0; i < topk_ids.size(); ++i) {
-        const auto flat_id = topk_ids.at<int32_t>(i);
-        const auto beam_id = flat_id / vocabulary_size;
-        const auto word_id = flat_id % vocabulary_size;
-        const auto batch_id = i / beam_size;
+        auto flat_id = topk_ids.at<int32_t>(i);
+        auto beam_id = flat_id / vocabulary_size;
+        auto word_id = flat_id % vocabulary_size;
+        auto batch_id = i / beam_size;
+        if (!candidates.empty())
+          word_id = candidates.at<int32_t>(word_id);
         topk_ids.at<int32_t>(i) = word_id;
         gather_indices.at<int32_t>(i) = beam_id + batch_id * beam_size;
       }
@@ -176,13 +178,9 @@ namespace opennmt {
       size_t finished_count = 0;
       for (size_t i = 0; i < cur_batch_size; ++i) {
         auto pred_id = topk_ids.at<int32_t>({i, 0});
-        if (!candidates.empty())
-          pred_id = candidates.at<int32_t>(pred_id);
         if (pred_id == static_cast<int32_t>(end_token) || step + 1 == max_steps) {
           for (size_t t = 1; t < alive_seq.dim(-1); ++t) {
             size_t id = alive_seq.at<int32_t>({i * beam_size, t});
-            if (!candidates.empty())
-              id = candidates.at<int32_t>(id);
             if (id == end_token)
               break;
             sampled_ids[batch_offset[i]].push_back(id);
@@ -262,12 +260,12 @@ namespace opennmt {
         if (!candidates.empty())
           true_id = candidates.at<int32_t>(best);
         size_t batch_id = batch_offset[i];
-        if (best == end_token) {
+        if (true_id == end_token) {
           finished[batch_id] = true;
           finished_batch[i] = true;
           one_finished = true;
         } else {
-          sample_from.at<int32_t>(i) = best;
+          sample_from.at<int32_t>(i) = true_id;
           sampled_ids[batch_id].push_back(true_id);
           ++count_alive;
         }
```