Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'python/tests/test_transformers.py')
-rw-r--r--python/tests/test_transformers.py38
1 file changed, 25 insertions, 13 deletions
diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index 9e0a646d..baaf7ef4 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -226,7 +226,8 @@ def test_transformers_lm_scoring(tmpdir):
"device", ["cpu"] + (["cuda"] if ctranslate2.get_cuda_device_count() > 0 else [])
)
@pytest.mark.parametrize("return_log_probs", [True, False])
-def test_transformers_lm_forward(tmpdir, device, return_log_probs):
+@pytest.mark.parametrize("tensor_input", [True, False])
+def test_transformers_lm_forward(tmpdir, device, return_log_probs, tensor_input):
import torch
import transformers
@@ -239,29 +240,40 @@ def test_transformers_lm_forward(tmpdir, device, return_log_probs):
output_dir = converter.convert(output_dir)
generator = ctranslate2.Generator(output_dir, device=device)
- inputs = tokenizer(["Hello world!"], return_tensors="pt")
-
- inputs.to(device)
- model.to(device)
+ text = ["Hello world!"]
with torch.no_grad():
+ inputs = tokenizer(text, return_tensors="pt")
+ inputs.to(device)
+ model.to(device)
output = model(**inputs)
ref_output = output.logits
if return_log_probs:
ref_output = torch.nn.functional.log_softmax(ref_output, dim=-1)
ref_output = ref_output.cpu().numpy()
- ids = inputs["input_ids"].to(torch.int32)
- lengths = inputs["attention_mask"].sum(1, dtype=torch.int32)
+ kwargs = dict(return_log_probs=return_log_probs)
- if device == "cpu":
- ids = ids.numpy()
- lengths = lengths.numpy()
+ if tensor_input:
+ inputs = tokenizer(text, return_length=True, return_tensors="pt")
+ inputs.to(device)
+ ids = inputs.input_ids.to(torch.int32)
+ lengths = inputs.length.to(torch.int32)
+
+ if device == "cpu":
+ ids = ids.numpy()
+ lengths = lengths.numpy()
- ids = ctranslate2.StorageView.from_array(ids)
- lengths = ctranslate2.StorageView.from_array(lengths)
+ ids = ctranslate2.StorageView.from_array(ids)
+ lengths = ctranslate2.StorageView.from_array(lengths)
- output = generator.forward_batch(ids, lengths, return_log_probs=return_log_probs)
+ with pytest.raises(ValueError, match="lengths"):
+ generator.forward_batch(ids, **kwargs)
+ output = generator.forward_batch(ids, lengths, **kwargs)
+
+ else:
+ ids = tokenizer(text).input_ids
+ output = generator.forward_batch(ids, **kwargs)
if device == "cpu":
output = np.array(output)