diff options
Diffstat (limited to 'python/tests/test_transformers.py')
-rw-r--r-- | python/tests/test_transformers.py | 38 |
1 files changed, 25 insertions, 13 deletions
diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py index 9e0a646d..baaf7ef4 100644 --- a/python/tests/test_transformers.py +++ b/python/tests/test_transformers.py @@ -226,7 +226,8 @@ def test_transformers_lm_scoring(tmpdir): "device", ["cpu"] + (["cuda"] if ctranslate2.get_cuda_device_count() > 0 else []) ) @pytest.mark.parametrize("return_log_probs", [True, False]) -def test_transformers_lm_forward(tmpdir, device, return_log_probs): +@pytest.mark.parametrize("tensor_input", [True, False]) +def test_transformers_lm_forward(tmpdir, device, return_log_probs, tensor_input): import torch import transformers @@ -239,29 +240,40 @@ def test_transformers_lm_forward(tmpdir, device, return_log_probs): output_dir = converter.convert(output_dir) generator = ctranslate2.Generator(output_dir, device=device) - inputs = tokenizer(["Hello world!"], return_tensors="pt") - - inputs.to(device) - model.to(device) + text = ["Hello world!"] with torch.no_grad(): + inputs = tokenizer(text, return_tensors="pt") + inputs.to(device) + model.to(device) output = model(**inputs) ref_output = output.logits if return_log_probs: ref_output = torch.nn.functional.log_softmax(ref_output, dim=-1) ref_output = ref_output.cpu().numpy() - ids = inputs["input_ids"].to(torch.int32) - lengths = inputs["attention_mask"].sum(1, dtype=torch.int32) + kwargs = dict(return_log_probs=return_log_probs) - if device == "cpu": - ids = ids.numpy() - lengths = lengths.numpy() + if tensor_input: + inputs = tokenizer(text, return_length=True, return_tensors="pt") + inputs.to(device) + ids = inputs.input_ids.to(torch.int32) + lengths = inputs.length.to(torch.int32) + + if device == "cpu": + ids = ids.numpy() + lengths = lengths.numpy() - ids = ctranslate2.StorageView.from_array(ids) - lengths = ctranslate2.StorageView.from_array(lengths) + ids = ctranslate2.StorageView.from_array(ids) + lengths = ctranslate2.StorageView.from_array(lengths) - output = generator.forward_batch(ids, lengths, return_log_probs=return_log_probs) + with pytest.raises(ValueError, match="lengths"): + generator.forward_batch(ids, **kwargs) + output = generator.forward_batch(ids, lengths, **kwargs) + + else: + ids = tokenizer(text).input_ids + output = generator.forward_batch(ids, **kwargs) if device == "cpu": output = np.array(output) |