github.com/OpenNMT/CTranslate2.git
author     Guillaume Klein <guillaumekln@users.noreply.github.com>  2022-10-28 18:51:02 +0300
committer  GitHub <noreply@github.com>                              2022-10-28 18:51:02 +0300
commit     771bc02fd60bc40d1d50fbdfe0fabe2f0aac7564 (patch)
tree       263207b9407f791f0493c14470e41d91f8d11c43
parent     145e488bb73168400fe0327e470abfdfd9dc8f18 (diff)
Accept more input types in forward_batch method (#955)
-rw-r--r--  include/ctranslate2/generator_pool.h          8
-rw-r--r--  include/ctranslate2/models/language_model.h  10
-rw-r--r--  python/cpp/generator.cc                      45
-rw-r--r--  python/cpp/utils.h                            1
-rw-r--r--  python/tests/test_transformers.py            38
-rw-r--r--  src/generator_pool.cc                        20
-rw-r--r--  src/models/language_model.cc                 30
7 files changed, 111 insertions, 41 deletions
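
As the updated docstring in python/cpp/generator.cc below describes, forward_batch now accepts three input forms from Python. A minimal usage sketch (the model directory and the example tokens/IDs are placeholders, not part of this commit):

    import numpy as np
    import ctranslate2

    generator = ctranslate2.Generator("ct2_model_dir")  # hypothetical converted model

    # 1) A batch of string tokens, looked up in the model vocabulary.
    output = generator.forward_batch([["Hello", "world", "!"]])

    # 2) A batch of token IDs; per-sequence lengths are inferred.
    output = generator.forward_batch([[101, 102, 103]])

    # 3) A dense int32 array plus explicit lengths, wrapped in StorageView.
    ids = ctranslate2.StorageView.from_array(np.array([[101, 102, 103]], dtype=np.int32))
    lengths = ctranslate2.StorageView.from_array(np.array([3], dtype=np.int32))
    output = generator.forward_batch(ids, lengths, return_log_probs=True)

    # On CPU, the returned StorageView can be viewed as a NumPy array (as in the test below).
    logits = np.array(output)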
diff --git a/include/ctranslate2/generator_pool.h b/include/ctranslate2/generator_pool.h
index 3bd33924..d450925e 100644
--- a/include/ctranslate2/generator_pool.h
+++ b/include/ctranslate2/generator_pool.h
@@ -30,6 +30,14 @@ namespace ctranslate2 {
const BatchType batch_type = BatchType::Examples);
std::future<StorageView>
+ forward_batch_async(std::vector<std::vector<std::string>> tokens,
+ const bool return_log_probs);
+
+ std::future<StorageView>
+ forward_batch_async(std::vector<std::vector<size_t>> ids,
+ const bool return_log_probs);
+
+ std::future<StorageView>
forward_batch_async(StorageView ids,
StorageView lengths,
const bool return_log_probs);
diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h
index a1dd55b2..e0bc0a8e 100644
--- a/include/ctranslate2/models/language_model.h
+++ b/include/ctranslate2/models/language_model.h
@@ -25,8 +25,9 @@ namespace ctranslate2 {
// Base class for generative language models.
class SequenceGeneratorReplica : public ModelReplica {
public:
- SequenceGeneratorReplica(const std::shared_ptr<const Model>& model)
+ SequenceGeneratorReplica(const std::shared_ptr<const LanguageModel>& model)
: ModelReplica(model)
+ , _model(model)
{
}
@@ -42,6 +43,10 @@ namespace ctranslate2 {
generate(const std::vector<std::vector<std::string>>& start_tokens,
const GenerationOptions& options = GenerationOptions());
+ StorageView forward(const std::vector<std::vector<std::string>>& tokens,
+ const bool return_log_probs);
+ StorageView forward(const std::vector<std::vector<size_t>>& ids,
+ const bool return_log_probs);
StorageView forward(const StorageView& ids,
const StorageView& lengths,
const bool return_log_probs);
@@ -65,6 +70,9 @@ namespace ctranslate2 {
const GenerationOptions& options) = 0;
virtual StorageView forward(const StorageView& ids, const StorageView& lengths) = 0;
+
+ private:
+ const std::shared_ptr<const LanguageModel> _model;
};
diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc
index 2e20abb1..74d5bb66 100644
--- a/python/cpp/generator.cc
+++ b/python/cpp/generator.cc
@@ -110,12 +110,27 @@ namespace ctranslate2 {
return maybe_wait_on_futures(std::move(futures), asynchronous);
}
- StorageViewWrapper forward_batch(const StorageViewWrapper& ids,
- const StorageViewWrapper& lengths,
- const bool return_log_probs) {
- auto future = _generator_pool.forward_batch_async(ids.get_view(),
- lengths.get_view(),
- return_log_probs);
+ StorageViewWrapper
+ forward_batch(const std::variant<BatchTokens, BatchIds, StorageViewWrapper>& inputs,
+ const std::optional<StorageViewWrapper>& lengths,
+ const bool return_log_probs) {
+ std::future<StorageView> future;
+
+ switch (inputs.index()) {
+ case 0:
+ future = _generator_pool.forward_batch_async(std::get<BatchTokens>(inputs), return_log_probs);
+ break;
+ case 1:
+ future = _generator_pool.forward_batch_async(std::get<BatchIds>(inputs), return_log_probs);
+ break;
+ case 2:
+ if (!lengths)
+ throw std::invalid_argument("lengths vector is required when passing a dense input");
+ const StorageView& ids_view = std::get<StorageViewWrapper>(inputs).get_view();
+ const StorageView& lengths_view = lengths.value().get_view();
+ future = _generator_pool.forward_batch_async(ids_view, lengths_view, return_log_probs);
+ break;
+ }
return StorageViewWrapper(future.get());
}
@@ -262,29 +277,27 @@ namespace ctranslate2 {
A list of scoring results.
)pbdoc")
- .def("forward_batch", &GeneratorWrapper::forward_batch,
- py::arg("ids"),
- py::arg("lengths"),
+ .def("forward_batch", &GeneratorWrapper::forward_batch,
+ py::arg("inputs"),
+ py::arg("lengths")=py::none(),
py::kw_only(),
py::arg("return_log_probs")=false,
py::call_guard<py::gil_scoped_release>(),
R"pbdoc(
- Forwards a batch of token IDs in the generator.
+ Forwards a batch of sequences in the generator.
Arguments:
- ids: The sequences of token IDs as a int32 array with shape
- ``[batch_size, max_length]``.
+ inputs: A batch of sequences either as string tokens or token IDs.
+ This argument can also be a dense int32 array with shape
+ ``[batch_size, max_length]`` (e.g. created from a Numpy array or PyTorch tensor).
lengths: The length of each sequence as an int32 array with shape
- ``[batch_size]``.
+ ``[batch_size]``. Required when :obj:`inputs` is a dense array.
return_log_probs: If ``True``, the method returns the log probabilities instead
of the unscaled logits.
Returns:
The output logits, or the output log probabilities if :obj:`return_log_probs`
is enabled.
-
- Note:
- :obj:`ids` and :obj:`lengths` must be on the same device as the model.
)pbdoc")
;
}
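
One behavioral detail of the binding above: lengths is only optional for the two list input forms. Passing a dense StorageView without it raises a ValueError (pybind11's mapping of the std::invalid_argument thrown in case 2), which the updated test below also asserts. A short sketch, reusing the hypothetical generator and inputs from the earlier example:

    try:
        generator.forward_batch(ids)  # dense input, lengths omitted
    except ValueError as err:
        print(err)  # "lengths vector is required when passing a dense input"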
diff --git a/python/cpp/utils.h b/python/cpp/utils.h
index 1b563ae4..f8616311 100644
--- a/python/cpp/utils.h
+++ b/python/cpp/utils.h
@@ -21,6 +21,7 @@ namespace ctranslate2 {
using StringOrMap = std::variant<std::string, std::unordered_map<std::string, std::string>>;
using Tokens = std::vector<std::string>;
using BatchTokens = std::vector<Tokens>;
+ using BatchIds = std::vector<std::vector<size_t>>;
class ComputeTypeResolver {
private:
diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py
index 9e0a646d..baaf7ef4 100644
--- a/python/tests/test_transformers.py
+++ b/python/tests/test_transformers.py
@@ -226,7 +226,8 @@ def test_transformers_lm_scoring(tmpdir):
"device", ["cpu"] + (["cuda"] if ctranslate2.get_cuda_device_count() > 0 else [])
)
@pytest.mark.parametrize("return_log_probs", [True, False])
-def test_transformers_lm_forward(tmpdir, device, return_log_probs):
+@pytest.mark.parametrize("tensor_input", [True, False])
+def test_transformers_lm_forward(tmpdir, device, return_log_probs, tensor_input):
import torch
import transformers
@@ -239,29 +240,40 @@ def test_transformers_lm_forward(tmpdir, device, return_log_probs):
output_dir = converter.convert(output_dir)
generator = ctranslate2.Generator(output_dir, device=device)
- inputs = tokenizer(["Hello world!"], return_tensors="pt")
-
- inputs.to(device)
- model.to(device)
+ text = ["Hello world!"]
with torch.no_grad():
+ inputs = tokenizer(text, return_tensors="pt")
+ inputs.to(device)
+ model.to(device)
output = model(**inputs)
ref_output = output.logits
if return_log_probs:
ref_output = torch.nn.functional.log_softmax(ref_output, dim=-1)
ref_output = ref_output.cpu().numpy()
- ids = inputs["input_ids"].to(torch.int32)
- lengths = inputs["attention_mask"].sum(1, dtype=torch.int32)
+ kwargs = dict(return_log_probs=return_log_probs)
- if device == "cpu":
- ids = ids.numpy()
- lengths = lengths.numpy()
+ if tensor_input:
+ inputs = tokenizer(text, return_length=True, return_tensors="pt")
+ inputs.to(device)
+ ids = inputs.input_ids.to(torch.int32)
+ lengths = inputs.length.to(torch.int32)
+
+ if device == "cpu":
+ ids = ids.numpy()
+ lengths = lengths.numpy()
- ids = ctranslate2.StorageView.from_array(ids)
- lengths = ctranslate2.StorageView.from_array(lengths)
+ ids = ctranslate2.StorageView.from_array(ids)
+ lengths = ctranslate2.StorageView.from_array(lengths)
- output = generator.forward_batch(ids, lengths, return_log_probs=return_log_probs)
+ with pytest.raises(ValueError, match="lengths"):
+ generator.forward_batch(ids, **kwargs)
+ output = generator.forward_batch(ids, lengths, **kwargs)
+
+ else:
+ ids = tokenizer(text).input_ids
+ output = generator.forward_batch(ids, **kwargs)
if device == "cpu":
output = np.array(output)
diff --git a/src/generator_pool.cc b/src/generator_pool.cc
index f55c1157..63f27807 100644
--- a/src/generator_pool.cc
+++ b/src/generator_pool.cc
@@ -56,6 +56,26 @@ namespace ctranslate2 {
}
std::future<StorageView>
+ GeneratorPool::forward_batch_async(std::vector<std::vector<std::string>> tokens,
+ const bool return_log_probs) {
+ return post<StorageView>(
+ [tokens = std::move(tokens), return_log_probs]
+ (models::SequenceGeneratorReplica& generator) {
+ return generator.forward(tokens, return_log_probs);
+ });
+ }
+
+ std::future<StorageView>
+ GeneratorPool::forward_batch_async(std::vector<std::vector<size_t>> ids,
+ const bool return_log_probs) {
+ return post<StorageView>(
+ [ids = std::move(ids), return_log_probs]
+ (models::SequenceGeneratorReplica& generator) {
+ return generator.forward(ids, return_log_probs);
+ });
+ }
+
+ std::future<StorageView>
GeneratorPool::forward_batch_async(StorageView ids,
StorageView lengths,
const bool return_log_probs) {
diff --git a/src/models/language_model.cc b/src/models/language_model.cc
index cb5718dc..3824758c 100644
--- a/src/models/language_model.cc
+++ b/src/models/language_model.cc
@@ -53,14 +53,19 @@ namespace ctranslate2 {
return run_generation(start_tokens, options);
}
- static void check_input_location(const std::string& input_name,
- const StorageView& input,
- const models::Model& model) {
- if (input.device() != model.device() || input.device_index() != model.device_index())
- throw std::invalid_argument("Input " + input_name
- + " (" + device_to_str(input.device(), input.device_index())
- + ") must be on the same device as the model ("
- + device_to_str(model.device(), model.device_index()) + ")");
+ StorageView
+ SequenceGeneratorReplica::forward(const std::vector<std::vector<std::string>>& tokens,
+ const bool return_log_probs) {
+ const auto& vocabulary = _model->get_vocabulary();
+ return forward(vocabulary.to_ids(tokens), return_log_probs);
+ }
+
+ StorageView
+ SequenceGeneratorReplica::forward(const std::vector<std::vector<size_t>>& ids,
+ const bool return_log_probs) {
+ StorageView lengths;
+ StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths);
+ return forward(input_ids, lengths, return_log_probs);
}
StorageView
@@ -69,12 +74,15 @@ namespace ctranslate2 {
const bool return_log_probs) {
PROFILE("SequenceGeneratorReplica::forward");
const auto& model = *this->model();
+ const auto device = model.device();
const auto scoped_device_setter = model.get_scoped_device_setter();
- check_input_location("ids", ids, model);
- check_input_location("lengths", lengths, model);
+ StorageView output;
+ if (ids.device() != device)
+ output = forward(ids.to(device), lengths.to(device));
+ else
+ output = forward(ids, lengths);
- StorageView output = forward(ids, lengths);
if (return_log_probs)
ops::LogSoftMax()(output);
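
The hunk above also changes device handling: instead of rejecting inputs that live on a different device than the model (the removed check_input_location), the forward pass now copies them to the model device with StorageView::to, which is why the docstring Note was dropped. A sketch of what this permits from Python, assuming a CUDA build and the same hypothetical model as above:

    # CPU-side NumPy inputs against a CUDA model: previously an error,
    # now the copy to the model device happens inside the forward pass.
    generator = ctranslate2.Generator("ct2_model_dir", device="cuda")
    ids = ctranslate2.StorageView.from_array(np.array([[101, 102, 103]], dtype=np.int32))
    lengths = ctranslate2.StorageView.from_array(np.array([3], dtype=np.int32))
    output = generator.forward_batch(ids, lengths)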