Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT-py.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuillaume Klein <guillaumekln@users.noreply.github.com>2021-04-30 13:41:11 +0300
committerGitHub <noreply@github.com>2021-04-30 13:41:11 +0300
commite4ab06cf9e7ba67548ec79d3bb4b90928c7cb6ae (patch)
tree7dc2c4a3a4fd252c5ffb7f0bc2fb05484f88d954
parentb1a46159ff5623406d3c5baa0280d12376754068 (diff)
Add more checks before converting checkpoints to CTranslate2 (#2053)
-rwxr-xr-xonmt/bin/release_model.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/onmt/bin/release_model.py b/onmt/bin/release_model.py
index 357cb0d9..66ddb69e 100755
--- a/onmt/bin/release_model.py
+++ b/onmt/bin/release_model.py
@@ -2,14 +2,19 @@
import argparse
import torch
+from onmt.modules.position_ffn import ActivationFunction
+
def get_ctranslate2_model_spec(opt):
"""Creates a CTranslate2 model specification from the model options."""
with_relative_position = getattr(opt, "max_relative_positions", 0) > 0
+ relu = ActivationFunction.relu
is_ct2_compatible = (
opt.encoder_type == "transformer"
and opt.decoder_type == "transformer"
+ and not getattr(opt, "aan_useffn", False)
and getattr(opt, "self_attn_type", "scaled-dot") == "scaled-dot"
+ and getattr(opt, "pos_ffn_activation_fn", relu) == relu
and ((opt.position_encoding and not with_relative_position)
or (with_relative_position and not opt.position_encoding)))
if not is_ct2_compatible: