Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuillaume Klein <guillaume.klein@systrangroup.com>2018-06-08 18:41:24 +0300
committerGuillaume Klein <guillaume.klein@systrangroup.com>2018-06-08 18:41:24 +0300
commit84b4509327fa3d59acf9158ab8ddea29cc6e6aa2 (patch)
tree2f5d54d959159defad1faa2a82f4790438b35f6f
parent40da1ed2421770fde5485cf9c6dcfa570e749cf3 (diff)
fix vmapbenchmark-3
-rw-r--r--cli/translate.cc5
-rw-r--r--src/decoder.cc18
2 files changed, 11 insertions, 12 deletions
diff --git a/cli/translate.cc b/cli/translate.cc
index 22bf71eb..a4d15a50 100644
--- a/cli/translate.cc
+++ b/cli/translate.cc
@@ -117,12 +117,13 @@ int main(int argc, char* argv[]) {
size_t max_batch_size = argc > 1 ? std::stoi(argv[1]) : 1;
size_t beam_size = argc > 2 ? std::stoi(argv[2]) : 1;
size_t inter_threads = argc > 3 ? std::stoi(argv[3]) : 1;
- std::string model_path = argc > 4 ? argv[4] : "/home/klein/dev/ctransformer/model.bin";
+ std::string model_path = argc > 4 ? argv[4] : "/home/klein/dev/ctransformer/ende_transformer.bin";
+ std::string vmap = argc > 5 ? argv[5] : "";
std::string vocabulary_path = "/home/klein/data/wmt-ende/wmtende.vocab";
opennmt::TransformerModel model(model_path, vocabulary_path);
std::vector<opennmt::Translator> translator_pool;
- translator_pool.emplace_back(model, 200, beam_size, 0.6, "");
+ translator_pool.emplace_back(model, 200, beam_size, 0.6, vmap);
for (size_t i = 1; i < inter_threads; ++i) {
translator_pool.emplace_back(translator_pool.front());
}
diff --git a/src/decoder.cc b/src/decoder.cc
index 8acab390..13abc9e4 100644
--- a/src/decoder.cc
+++ b/src/decoder.cc
@@ -163,10 +163,12 @@ namespace opennmt {
topk_ids.reshape({cur_batch_size, beam_size});
gather_indices.resize({cur_batch_size, beam_size});
for (size_t i = 0; i < topk_ids.size(); ++i) {
- const auto flat_id = topk_ids.at<int32_t>(i);
- const auto beam_id = flat_id / vocabulary_size;
- const auto word_id = flat_id % vocabulary_size;
- const auto batch_id = i / beam_size;
+ auto flat_id = topk_ids.at<int32_t>(i);
+ auto beam_id = flat_id / vocabulary_size;
+ auto word_id = flat_id % vocabulary_size;
+ auto batch_id = i / beam_size;
+ if (!candidates.empty())
+ word_id = candidates.at<int32_t>(word_id);
topk_ids.at<int32_t>(i) = word_id;
gather_indices.at<int32_t>(i) = beam_id + batch_id * beam_size;
}
@@ -176,13 +178,9 @@ namespace opennmt {
size_t finished_count = 0;
for (size_t i = 0; i < cur_batch_size; ++i) {
auto pred_id = topk_ids.at<int32_t>({i, 0});
- if (!candidates.empty())
- pred_id = candidates.at<int32_t>(pred_id);
if (pred_id == static_cast<int32_t>(end_token) || step + 1 == max_steps) {
for (size_t t = 1; t < alive_seq.dim(-1); ++t) {
size_t id = alive_seq.at<int32_t>({i * beam_size, t});
- if (!candidates.empty())
- id = candidates.at<int32_t>(id);
if (id == end_token)
break;
sampled_ids[batch_offset[i]].push_back(id);
@@ -262,12 +260,12 @@ namespace opennmt {
if (!candidates.empty())
true_id = candidates.at<int32_t>(best);
size_t batch_id = batch_offset[i];
- if (best == end_token) {
+ if (true_id == end_token) {
finished[batch_id] = true;
finished_batch[i] = true;
one_finished = true;
} else {
- sample_from.at<int32_t>(i) = best;
+ sample_from.at<int32_t>(i) = true_id;
sampled_ids[batch_id].push_back(true_id);
++count_alive;
}