From 771bc02fd60bc40d1d50fbdfe0fabe2f0aac7564 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Fri, 28 Oct 2022 17:51:02 +0200
Subject: Accept more input types in forward_batch method (#955)
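
With this change, Generator.forward_batch accepts a batch of string tokens,
a batch of token IDs, or a dense int32 array, instead of only a StorageView
pair. A minimal usage sketch of the three input types (the model directory
and token IDs below are illustrative, not taken from this patch):

    import numpy as np
    import ctranslate2

    generator = ctranslate2.Generator("ct2_model/")  # hypothetical model path

    # 1. String tokens: converted to IDs with the model vocabulary
    #    (the tokens must exist in that vocabulary).
    logits = generator.forward_batch([["Hello", "world", "!"]])

    # 2. Token IDs: padded into a dense batch internally.
    logits = generator.forward_batch([[101, 2054, 999]])  # illustrative IDs

    # 3. Dense int32 array: the lengths argument must be passed explicitly.
    ids = np.array([[101, 2054, 999]], dtype=np.int32)
    lengths = np.array([3], dtype=np.int32)
    log_probs = generator.forward_batch(
        ctranslate2.StorageView.from_array(ids),
        ctranslate2.StorageView.from_array(lengths),
        return_log_probs=True,
    )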
---
 include/ctranslate2/generator_pool.h        |  8 +++++
 include/ctranslate2/models/language_model.h | 10 ++++++-
 python/cpp/generator.cc                     | 45 +++++++++++++++++++----------
 python/cpp/utils.h                          |  1 +
 python/tests/test_transformers.py           | 38 +++++++++++++++---------
 src/generator_pool.cc                       | 20 +++++++++++++
 src/models/language_model.cc                | 30 ++++++++++++-------
 7 files changed, 111 insertions(+), 41 deletions(-)
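
Note (a sketch, not part of the original patch): the rewritten
SequenceGeneratorReplica::forward also replaces the check_input_location
guard with an implicit device transfer, so dense inputs no longer have to
live on the model device. Assuming a CUDA build and a converted model in a
hypothetical "ct2_model/" directory:

    import numpy as np
    import ctranslate2

    generator = ctranslate2.Generator("ct2_model/", device="cuda")

    # CPU arrays are now moved to the model device internally; before this
    # patch, this call raised an error requiring inputs on the model device.
    ids = ctranslate2.StorageView.from_array(
        np.array([[101, 2054, 999]], dtype=np.int32)  # illustrative IDs
    )
    lengths = ctranslate2.StorageView.from_array(np.array([3], dtype=np.int32))
    logits = generator.forward_batch(ids, lengths)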
diff --git a/include/ctranslate2/generator_pool.h b/include/ctranslate2/generator_pool.h
index 3bd33924..d450925e 100644
--- a/include/ctranslate2/generator_pool.h
+++ b/include/ctranslate2/generator_pool.h
@@ -29,6 +29,14 @@ namespace ctranslate2 {
                const size_t max_batch_size = 0,
                const BatchType batch_type = BatchType::Examples);
 
+    std::future<StorageView>
+    forward_batch_async(std::vector<std::vector<std::string>> tokens,
+                        const bool return_log_probs);
+
+    std::future<StorageView>
+    forward_batch_async(std::vector<std::vector<size_t>> ids,
+                        const bool return_log_probs);
+
     std::future<StorageView>
     forward_batch_async(StorageView ids,
                         StorageView lengths,
diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h
index a1dd55b2..e0bc0a8e 100644
--- a/include/ctranslate2/models/language_model.h
+++ b/include/ctranslate2/models/language_model.h
@@ -25,8 +25,9 @@ namespace ctranslate2 {
     // Base class for generative language models.
     class SequenceGeneratorReplica : public ModelReplica {
     public:
-      SequenceGeneratorReplica(const std::shared_ptr<const Model>& model)
+      SequenceGeneratorReplica(const std::shared_ptr<const LanguageModel>& model)
         : ModelReplica(model)
+        , _model(model)
       {
       }
 
@@ -42,6 +43,10 @@ namespace ctranslate2 {
       generate(const std::vector<std::vector<std::string>>& start_tokens,
               const GenerationOptions& options = GenerationOptions());
 
+      StorageView forward(const std::vector<std::vector<std::string>>& tokens,
+                          const bool return_log_probs);
+      StorageView forward(const std::vector<std::vector<size_t>>& ids,
+                          const bool return_log_probs);
       StorageView forward(const StorageView& ids,
                           const StorageView& lengths,
                           const bool return_log_probs);
@@ -65,6 +70,9 @@ namespace ctranslate2 {
                    const GenerationOptions& options) = 0;
       virtual StorageView forward(const StorageView& ids,
                                   const StorageView& lengths) = 0;
+
+    private:
+      const std::shared_ptr<const LanguageModel> _model;
     };
 
diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc
index 2e20abb1..74d5bb66 100644
--- a/python/cpp/generator.cc
+++ b/python/cpp/generator.cc
@@ -110,12 +110,27 @@ namespace ctranslate2 {
         return maybe_wait_on_futures(std::move(futures), asynchronous);
       }
 
-      StorageViewWrapper forward_batch(const StorageViewWrapper& ids,
-                                       const StorageViewWrapper& lengths,
-                                       const bool return_log_probs) {
-        auto future = _generator_pool.forward_batch_async(ids.get_view(),
-                                                          lengths.get_view(),
-                                                          return_log_probs);
+      StorageViewWrapper
+      forward_batch(const std::variant<BatchTokens, BatchIds, StorageViewWrapper>& inputs,
+                    const std::optional<StorageViewWrapper>& lengths,
+                    const bool return_log_probs) {
+        std::future<StorageView> future;
+
+        switch (inputs.index()) {
+        case 0:
+          future = _generator_pool.forward_batch_async(std::get<BatchTokens>(inputs), return_log_probs);
+          break;
+        case 1:
+          future = _generator_pool.forward_batch_async(std::get<BatchIds>(inputs), return_log_probs);
+          break;
+        case 2:
+          if (!lengths)
+            throw std::invalid_argument("lengths vector is required when passing a dense input");
+          const StorageView& ids_view = std::get<StorageViewWrapper>(inputs).get_view();
+          const StorageView& lengths_view = lengths.value().get_view();
+          future = _generator_pool.forward_batch_async(ids_view,
+                                                       lengths_view,
+                                                       return_log_probs);
+          break;
+        }
+
         return StorageViewWrapper(future.get());
       }
 
@@ -262,29 +277,27 @@ namespace ctranslate2 {
 
              A list of scoring results.
          )pbdoc")
-      .def("forward_batch", &GeneratorWrapper::forward_batch,
-           py::arg("ids"),
-           py::arg("lengths"),
+      .def("forward_batch", &GeneratorWrapper::forward_batch,
+           py::arg("inputs"),
+           py::arg("lengths")=py::none(),
            py::kw_only(),
            py::arg("return_log_probs")=false,
            py::call_guard<py::gil_scoped_release>(),
            R"pbdoc(
-             Forwards a batch of token IDs in the generator.
+             Forwards a batch of sequences in the generator.
 
              Arguments:
-               ids: The sequences of token IDs as a int32 array with shape
-                 ``[batch_size, max_length]``.
+               inputs: A batch of sequences either as string tokens or token IDs.
+                 This argument can also be a dense int32 array with shape
+                 ``[batch_size, max_length]`` (e.g. created from a Numpy array
+                 or PyTorch tensor).
                lengths: The length of each sequence as a int32 array with shape
-                 ``[batch_size]``.
+                 ``[batch_size]``. Required when :obj:`inputs` is a dense array.
               return_log_probs: If ``True``, the method returns the log probabilties
                 instead of the unscaled logits.
 
             Returns:
               The output logits, or the output log probabilities if :obj:`return_log_probs`
               is enabled.
-
-            Note:
-              :obj:`ids` and :obj:`lengths` must be on the same device as the model.
          )pbdoc")
       ;
   }
diff --git a/python/cpp/utils.h b/python/cpp/utils.h
index 1b563ae4..f8616311 100644
--- a/python/cpp/utils.h
+++ b/python/cpp/utils.h
@@ -21,6 +21,7 @@ namespace ctranslate2 {
     using StringOrMap = std::variant<std::string, std::unordered_map<std::string, std::string>>;
     using Tokens = std::vector<std::string>;
     using BatchTokens = std::vector<Tokens>;
+    using BatchIds = std::vector<std::vector<size_t>>;
 
     class ComputeTypeResolver {
     private:
diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index 9e0a646d..baaf7ef4 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -226,7 +226,8 @@ def test_transformers_lm_scoring(tmpdir):
     "device", ["cpu"] + (["cuda"] if ctranslate2.get_cuda_device_count() > 0 else [])
 )
 @pytest.mark.parametrize("return_log_probs", [True, False])
-def test_transformers_lm_forward(tmpdir, device, return_log_probs):
+@pytest.mark.parametrize("tensor_input", [True, False])
+def test_transformers_lm_forward(tmpdir, device, return_log_probs, tensor_input):
     import torch
     import transformers
 
@@ -239,29 +240,40 @@ def test_transformers_lm_forward(tmpdir, device, return_log_probs):
     output_dir = converter.convert(output_dir)
     generator = ctranslate2.Generator(output_dir, device=device)
 
-    inputs = tokenizer(["Hello world!"], return_tensors="pt")
-
-    inputs.to(device)
-    model.to(device)
+    text = ["Hello world!"]
 
     with torch.no_grad():
+        inputs = tokenizer(text, return_tensors="pt")
+        inputs.to(device)
+        model.to(device)
         output = model(**inputs)
         ref_output = output.logits
         if return_log_probs:
             ref_output = torch.nn.functional.log_softmax(ref_output, dim=-1)
         ref_output = ref_output.cpu().numpy()
 
-    ids = inputs["input_ids"].to(torch.int32)
-    lengths = inputs["attention_mask"].sum(1, dtype=torch.int32)
+    kwargs = dict(return_log_probs=return_log_probs)
 
-    if device == "cpu":
-        ids = ids.numpy()
-        lengths = lengths.numpy()
+    if tensor_input:
+        inputs = tokenizer(text, return_length=True, return_tensors="pt")
+        inputs.to(device)
+        ids = inputs.input_ids.to(torch.int32)
+        lengths = inputs.length.to(torch.int32)
+
+        if device == "cpu":
+            ids = ids.numpy()
+            lengths = lengths.numpy()
 
-    ids = ctranslate2.StorageView.from_array(ids)
-    lengths = ctranslate2.StorageView.from_array(lengths)
+        ids = ctranslate2.StorageView.from_array(ids)
+        lengths = ctranslate2.StorageView.from_array(lengths)
 
-    output = generator.forward_batch(ids, lengths, return_log_probs=return_log_probs)
+        with pytest.raises(ValueError, match="lengths"):
+            generator.forward_batch(ids, **kwargs)
+        output = generator.forward_batch(ids, lengths, **kwargs)
+
+    else:
+        ids = tokenizer(text).input_ids
+        output = generator.forward_batch(ids, **kwargs)
 
     if device == "cpu":
         output = np.array(output)
diff --git a/src/generator_pool.cc b/src/generator_pool.cc
index f55c1157..63f27807 100644
--- a/src/generator_pool.cc
+++ b/src/generator_pool.cc
@@ -55,6 +55,26 @@ namespace ctranslate2 {
       });
   }
 
+  std::future<StorageView>
+  GeneratorPool::forward_batch_async(std::vector<std::vector<std::string>> tokens,
+                                     const bool return_log_probs) {
+    return post<StorageView>(
+      [tokens = std::move(tokens), return_log_probs]
+      (models::SequenceGeneratorReplica& generator) {
+        return generator.forward(tokens, return_log_probs);
+      });
+  }
+
+  std::future<StorageView>
+  GeneratorPool::forward_batch_async(std::vector<std::vector<size_t>> ids,
+                                     const bool return_log_probs) {
+    return post<StorageView>(
+      [ids = std::move(ids), return_log_probs]
+      (models::SequenceGeneratorReplica& generator) {
+        return generator.forward(ids, return_log_probs);
+      });
+  }
+
   std::future<StorageView>
   GeneratorPool::forward_batch_async(StorageView ids,
                                      StorageView lengths,
diff --git a/src/models/language_model.cc b/src/models/language_model.cc
index cb5718dc..3824758c 100644
--- a/src/models/language_model.cc
+++ b/src/models/language_model.cc
@@ -53,14 +53,19 @@ namespace ctranslate2 {
       return run_generation(start_tokens, options);
     }
 
-    static void check_input_location(const std::string& input_name,
-                                     const StorageView& input,
-                                     const models::Model& model) {
-      if (input.device() != model.device() || input.device_index() != model.device_index())
-        throw std::invalid_argument("Input " + input_name
-                                    + " (" + device_to_str(input.device(), input.device_index())
-                                    + ") must be on the same device as the model ("
-                                    + device_to_str(model.device(), model.device_index()) + ")");
+    StorageView
+    SequenceGeneratorReplica::forward(const std::vector<std::vector<std::string>>& tokens,
+                                      const bool return_log_probs) {
+      const auto& vocabulary = _model->get_vocabulary();
+      return forward(vocabulary.to_ids(tokens), return_log_probs);
+    }
+
+    StorageView
+    SequenceGeneratorReplica::forward(const std::vector<std::vector<size_t>>& ids,
+                                      const bool return_log_probs) {
+      StorageView lengths;
+      StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths);
+      return forward(input_ids, lengths, return_log_probs);
     }
 
     StorageView
@@ -69,12 +74,15 @@ namespace ctranslate2 {
                                       const bool return_log_probs) {
       PROFILE("SequenceGeneratorReplica::forward");
       const auto& model = *this->model();
+      const auto device = model.device();
       const auto scoped_device_setter = model.get_scoped_device_setter();
 
-      check_input_location("ids", ids, model);
-      check_input_location("lengths", lengths, model);
+      StorageView output;
+      if (ids.device() != device)
+        output = forward(ids.to(device), lengths.to(device));
+      else
+        output = forward(ids, lengths);
 
-      StorageView output = forward(ids, lengths);
       if (return_log_probs)
         ops::LogSoftMax()(output);
--
cgit v1.2.3