Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/models/sequence_to_sequence.cc')
-rw-r--r-- src/models/sequence_to_sequence.cc 37
1 file changed, 21 insertions, 16 deletions
diff --git a/src/models/sequence_to_sequence.cc b/src/models/sequence_to_sequence.cc
index a0c61a93..c8616f5a 100644
--- a/src/models/sequence_to_sequence.cc
+++ b/src/models/sequence_to_sequence.cc
@@ -16,9 +16,9 @@ namespace ctranslate2 {
void SequenceToSequenceModel::load_vocabularies(ModelReader& model_reader) {
{
VocabularyInfo vocab_info;
- vocab_info.unk_token = get_attribute_with_default<std::string>("unk_token", "<unk>");
- vocab_info.bos_token = get_attribute_with_default<std::string>("bos_token", "<s>");
- vocab_info.eos_token = get_attribute_with_default<std::string>("eos_token", "</s>");
+ vocab_info.unk_token = config["unk_token"];
+ vocab_info.bos_token = config["bos_token"];
+ vocab_info.eos_token = config["eos_token"];
auto shared_vocabulary = model_reader.get_file(shared_vocabulary_file);
if (shared_vocabulary) {
@@ -63,11 +63,22 @@ namespace ctranslate2 {
}
void SequenceToSequenceModel::initialize(ModelReader& model_reader) {
+ if (binary_version() < 6) {
+ config["unk_token"] = get_attribute_with_default<std::string>("unk_token", "<unk>");
+ config["bos_token"] = get_attribute_with_default<std::string>("bos_token", "<s>");
+ config["eos_token"] = get_attribute_with_default<std::string>("eos_token", "</s>");
+ config["add_source_bos"] = get_flag_with_default("with_source_bos", false);
+ config["add_source_eos"] = get_flag_with_default("with_source_eos", false);
+
+ if (get_flag_with_default("user_decoder_start_tokens", false))
+ config["decoder_start_token"] = nullptr;
+ else if (get_flag_with_default("with_target_bos", true))
+ config["decoder_start_token"] = config["bos_token"];
+ else
+ config["decoder_start_token"] = config["eos_token"];
+ }
+
load_vocabularies(model_reader);
- _with_source_bos = get_flag_with_default("with_source_bos", false);
- _with_source_eos = get_flag_with_default("with_source_eos", false);
- _with_target_bos = get_flag_with_default("with_target_bos", true);
- _user_decoder_start_tokens = get_flag_with_default("user_decoder_start_tokens", false);
}
size_t SequenceToSequenceModel::num_source_vocabularies() const {
@@ -165,13 +176,7 @@ namespace ctranslate2 {
bool is_prefix) const {
const auto& target_vocabulary = _model->get_target_vocabulary();
const std::string* suffix = &target_vocabulary.eos_token();
- const std::string* prefix = nullptr;
- if (!_model->user_decoder_start_tokens()) {
- if (_model->with_target_bos())
- prefix = &target_vocabulary.bos_token();
- else
- prefix = &target_vocabulary.eos_token();
- }
+ const std::string* prefix = _model->decoder_start_token();
if (is_prefix) {
suffix = nullptr;
@@ -263,7 +268,7 @@ namespace ctranslate2 {
const std::vector<std::string>& target,
const ScoringOptions& options,
ScoringResult& result) {
- if (_model->user_decoder_start_tokens() && target.empty()) {
+ if (!_model->decoder_start_token() && target.empty()) {
return true;
}
@@ -419,7 +424,7 @@ namespace ctranslate2 {
if (!target.empty()) {
hypothesis = target;
- if (_model->user_decoder_start_tokens())
+ if (!_model->decoder_start_token())
hypothesis.erase(hypothesis.begin());
if (hypothesis.size() > options.max_decoding_length)
hypothesis.resize(options.max_decoding_length);