Diffstat (limited to 'python/ctranslate2/converters/transformers.py')
-rw-r--r-- | python/ctranslate2/converters/transformers.py | 19
1 file changed, 9 insertions, 10 deletions
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index ba90d494..d8e88108 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -90,15 +90,20 @@ class ModelLoader(abc.ABC):
         if isinstance(spec, model_spec.SequenceToSequenceModelSpec):
             spec.register_source_vocabulary(tokens)
             spec.register_target_vocabulary(tokens)
+
+            if spec.config.decoder_start_token is not None:
+                spec.config.decoder_start_token = tokenizer.decode(
+                    model.config.decoder_start_token_id
+                )
         else:
             spec.register_vocabulary(tokens)
 
         if tokenizer.bos_token is not None:
-            spec.bos_token = tokenizer.bos_token
+            spec.config.bos_token = tokenizer.bos_token
         if tokenizer.eos_token is not None:
-            spec.eos_token = tokenizer.eos_token
+            spec.config.eos_token = tokenizer.eos_token
         if tokenizer.unk_token is not None:
-            spec.unk_token = tokenizer.unk_token
+            spec.config.unk_token = tokenizer.unk_token
 
         return spec
 
@@ -147,7 +152,6 @@ class BartLoader(ModelLoader):
             activation=_SUPPORTED_ACTIVATIONS[model.config.activation_function],
             layernorm_embedding=getattr(model.config, "normalize_embedding", True),
         )
-        spec.with_target_bos = False
 
         self.set_encoder(spec.encoder, model.model.encoder)
         self.set_decoder(spec.decoder, model.model.decoder)
@@ -325,7 +329,7 @@ class MBartLoader(BartLoader):
 
         # MBart-25 passes the language code as the decoder start token.
         if model.config.tokenizer_class in ("MBartTokenizer", None):
-            spec.user_decoder_start_tokens = True
+            spec.config.decoder_start_token = None
 
         return spec
 
@@ -336,11 +340,6 @@ class PegasusLoader(BartLoader):
     def architecture_name(self):
         return "PegasusForConditionalGeneration"
 
-    def get_model_spec(self, model):
-        spec = super().get_model_spec(model)
-        spec.with_target_bos = True
-        return spec
-
    def get_vocabulary(self, model, tokenizer):
         tokens = super().get_vocabulary(model, tokenizer)
         tokenizer.bos_token = tokens[model.config.pad_token_id]
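
For context on the first hunk: the converter now resolves the decoder start token by decoding the HuggingFace decoder_start_token_id back into its text form, instead of tracking it with boolean flags on the spec. Below is a minimal standalone sketch of that same lookup, assuming a local Transformers install; the checkpoint name is an arbitrary example, not something taken from this commit.

    # Sketch: resolve a model's decoder start token the way the converter
    # now does, by decoding decoder_start_token_id with the tokenizer.
    from transformers import AutoConfig, AutoTokenizer

    model_name = "facebook/bart-large-cnn"  # hypothetical example checkpoint
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if config.decoder_start_token_id is not None:
        # decode() accepts a single token id and returns its surface string.
        print(tokenizer.decode(config.decoder_start_token_id))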