Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'python/ctranslate2/converters/transformers.py')
-rw-r--r--python/ctranslate2/converters/transformers.py19
1 files changed, 9 insertions, 10 deletions
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index ba90d494..d8e88108 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -90,15 +90,20 @@ class ModelLoader(abc.ABC):
if isinstance(spec, model_spec.SequenceToSequenceModelSpec):
spec.register_source_vocabulary(tokens)
spec.register_target_vocabulary(tokens)
+
+ if spec.config.decoder_start_token is not None:
+ spec.config.decoder_start_token = tokenizer.decode(
+ model.config.decoder_start_token_id
+ )
else:
spec.register_vocabulary(tokens)
if tokenizer.bos_token is not None:
- spec.bos_token = tokenizer.bos_token
+ spec.config.bos_token = tokenizer.bos_token
if tokenizer.eos_token is not None:
- spec.eos_token = tokenizer.eos_token
+ spec.config.eos_token = tokenizer.eos_token
if tokenizer.unk_token is not None:
- spec.unk_token = tokenizer.unk_token
+ spec.config.unk_token = tokenizer.unk_token
return spec
@@ -147,7 +152,6 @@ class BartLoader(ModelLoader):
activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
layernorm_embedding=getattr(model.config, "normalize_embedding", True),
)
- spec.with_target_bos = False
self.set_encoder(spec.encoder, model.model.encoder)
self.set_decoder(spec.decoder, model.model.decoder)
@@ -325,7 +329,7 @@ class MBartLoader(BartLoader):
# MBart-25 passes the language code as the decoder start token.
if model.config.tokenizer_class in ("MBartTokenizer", None):
- spec.user_decoder_start_tokens = True
+ spec.config.decoder_start_token = None
return spec
@@ -336,11 +340,6 @@ class PegasusLoader(BartLoader):
def architecture_name(self):
return "PegasusForConditionalGeneration"
- def get_model_spec(self, model):
- spec = super().get_model_spec(model)
- spec.with_target_bos = True
- return spec
-
def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)
tokenizer.bos_token = tokens[model.config.pad_token_id]