Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/models/language_model.cc')
-rw-r--r--src/models/language_model.cc30
1 files changed, 19 insertions, 11 deletions
diff --git a/src/models/language_model.cc b/src/models/language_model.cc
index cb5718dc..3824758c 100644
--- a/src/models/language_model.cc
+++ b/src/models/language_model.cc
@@ -53,14 +53,19 @@ namespace ctranslate2 {
return run_generation(start_tokens, options);
}
- static void check_input_location(const std::string& input_name,
- const StorageView& input,
- const models::Model& model) {
- if (input.device() != model.device() || input.device_index() != model.device_index())
- throw std::invalid_argument("Input " + input_name
- + " (" + device_to_str(input.device(), input.device_index())
- + ") must be on the same device as the model ("
- + device_to_str(model.device(), model.device_index()) + ")");
+ StorageView
+ SequenceGeneratorReplica::forward(const std::vector<std::vector<std::string>>& tokens,
+ const bool return_log_probs) {
+ const auto& vocabulary = _model->get_vocabulary();
+ return forward(vocabulary.to_ids(tokens), return_log_probs);
+ }
+
+ StorageView
+ SequenceGeneratorReplica::forward(const std::vector<std::vector<size_t>>& ids,
+ const bool return_log_probs) {
+ StorageView lengths;
+ StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths);
+ return forward(input_ids, lengths, return_log_probs);
}
StorageView
@@ -69,12 +74,15 @@ namespace ctranslate2 {
const bool return_log_probs) {
PROFILE("SequenceGeneratorReplica::forward");
const auto& model = *this->model();
+ const auto device = model.device();
const auto scoped_device_setter = model.get_scoped_device_setter();
- check_input_location("ids", ids, model);
- check_input_location("lengths", lengths, model);
+ StorageView output;
+ if (ids.device() != device)
+ output = forward(ids.to(device), lengths.to(device));
+ else
+ output = forward(ids, lengths);
- StorageView output = forward(ids, lengths);
if (return_log_probs)
ops::LogSoftMax()(output);