diff options
Diffstat (limited to 'src/models/sequence_to_sequence.cc')
-rw-r--r-- | src/models/sequence_to_sequence.cc | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/src/models/sequence_to_sequence.cc b/src/models/sequence_to_sequence.cc index a0c61a93..c8616f5a 100644 --- a/src/models/sequence_to_sequence.cc +++ b/src/models/sequence_to_sequence.cc @@ -16,9 +16,9 @@ namespace ctranslate2 { void SequenceToSequenceModel::load_vocabularies(ModelReader& model_reader) { { VocabularyInfo vocab_info; - vocab_info.unk_token = get_attribute_with_default<std::string>("unk_token", "<unk>"); - vocab_info.bos_token = get_attribute_with_default<std::string>("bos_token", "<s>"); - vocab_info.eos_token = get_attribute_with_default<std::string>("eos_token", "</s>"); + vocab_info.unk_token = config["unk_token"]; + vocab_info.bos_token = config["bos_token"]; + vocab_info.eos_token = config["eos_token"]; auto shared_vocabulary = model_reader.get_file(shared_vocabulary_file); if (shared_vocabulary) { @@ -63,11 +63,22 @@ namespace ctranslate2 { } void SequenceToSequenceModel::initialize(ModelReader& model_reader) { + if (binary_version() < 6) { + config["unk_token"] = get_attribute_with_default<std::string>("unk_token", "<unk>"); + config["bos_token"] = get_attribute_with_default<std::string>("bos_token", "<s>"); + config["eos_token"] = get_attribute_with_default<std::string>("eos_token", "</s>"); + config["add_source_bos"] = get_flag_with_default("with_source_bos", false); + config["add_source_eos"] = get_flag_with_default("with_source_eos", false); + + if (get_flag_with_default("user_decoder_start_tokens", false)) + config["decoder_start_token"] = nullptr; + else if (get_flag_with_default("with_target_bos", true)) + config["decoder_start_token"] = config["bos_token"]; + else + config["decoder_start_token"] = config["eos_token"]; + } + load_vocabularies(model_reader); - _with_source_bos = get_flag_with_default("with_source_bos", false); - _with_source_eos = get_flag_with_default("with_source_eos", false); - _with_target_bos = get_flag_with_default("with_target_bos", true); - _user_decoder_start_tokens = get_flag_with_default("user_decoder_start_tokens", false); } size_t SequenceToSequenceModel::num_source_vocabularies() const { @@ -165,13 +176,7 @@ namespace ctranslate2 { bool is_prefix) const { const auto& target_vocabulary = _model->get_target_vocabulary(); const std::string* suffix = &target_vocabulary.eos_token(); - const std::string* prefix = nullptr; - if (!_model->user_decoder_start_tokens()) { - if (_model->with_target_bos()) - prefix = &target_vocabulary.bos_token(); - else - prefix = &target_vocabulary.eos_token(); - } + const std::string* prefix = _model->decoder_start_token(); if (is_prefix) { suffix = nullptr; @@ -263,7 +268,7 @@ namespace ctranslate2 { const std::vector<std::string>& target, const ScoringOptions& options, ScoringResult& result) { - if (_model->user_decoder_start_tokens() && target.empty()) { + if (!_model->decoder_start_token() && target.empty()) { return true; } @@ -419,7 +424,7 @@ namespace ctranslate2 { if (!target.empty()) { hypothesis = target; - if (_model->user_decoder_start_tokens()) + if (!_model->decoder_start_token()) hypothesis.erase(hypothesis.begin()); if (hypothesis.size() > options.max_decoding_length) hypothesis.resize(options.max_decoding_length); |