diff options
author | Guillaume Klein <guillaumekln@users.noreply.github.com> | 2021-04-30 13:41:11 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-30 13:41:11 +0300 |
commit | e4ab06cf9e7ba67548ec79d3bb4b90928c7cb6ae (patch) | |
tree | 7dc2c4a3a4fd252c5ffb7f0bc2fb05484f88d954 | |
parent | b1a46159ff5623406d3c5baa0280d12376754068 (diff) |
Add more checks before converting checkpoints to CTranslate2 (#2053)
-rwxr-xr-x | onmt/bin/release_model.py | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/onmt/bin/release_model.py b/onmt/bin/release_model.py index 357cb0d9..66ddb69e 100755 --- a/onmt/bin/release_model.py +++ b/onmt/bin/release_model.py @@ -2,14 +2,19 @@ import argparse import torch +from onmt.modules.position_ffn import ActivationFunction + def get_ctranslate2_model_spec(opt): """Creates a CTranslate2 model specification from the model options.""" with_relative_position = getattr(opt, "max_relative_positions", 0) > 0 + relu = ActivationFunction.relu is_ct2_compatible = ( opt.encoder_type == "transformer" and opt.decoder_type == "transformer" + and not getattr(opt, "aan_useffn", False) and getattr(opt, "self_attn_type", "scaled-dot") == "scaled-dot" + and getattr(opt, "pos_ffn_activation_fn", relu) == relu and ((opt.position_encoding and not with_relative_position) or (with_relative_position and not opt.position_encoding))) if not is_ct2_compatible: |