Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT-py.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuillaume Klein <guillaumekln@users.noreply.github.com>2021-10-21 13:11:21 +0300
committerGitHub <noreply@github.com>2021-10-21 13:11:21 +0300
commite2628a2192b39a25606b1638ea2b2a38f21306e5 (patch)
treed281d74a3c6b909a5c5ca9e4e30c2abee3fc60f6
parent13bd3175a7222656090e878f0bdda0994f2a7256 (diff)
Use the same number of threads as PyTorch in CTranslate2 (#2117)
-rw-r--r--onmt/translate/translation_server.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/onmt/translate/translation_server.py b/onmt/translate/translation_server.py
index eedf555a..e0e968f5 100644
--- a/onmt/translate/translation_server.py
+++ b/onmt/translate/translation_server.py
@@ -111,7 +111,7 @@ class CTranslate2Translator(object):
default_for_translator = {
"inter_threads": 1,
- "intra_threads": 1,
+ "intra_threads": torch.get_num_threads(),
"compute_type": "default",
}
for name, value in default_for_translator.items():
@@ -189,9 +189,9 @@ class TranslationServer(object):
'model_root': conf.get('model_root', self.models_root),
'ct2_model': conf.get('ct2_model', None),
'ct2_translator_args': conf.get('ct2_translator_args',
- None),
+ {}),
'ct2_translate_batch_args': conf.get(
- 'ct2_translate_batch_args', None),
+ 'ct2_translate_batch_args', {}),
}
kwargs = {k: v for (k, v) in kwargs.items() if v is not None}
model_id = conf.get("id", None)