diff options
Diffstat (limited to 'src/models/language_model.cc')
-rw-r--r-- | src/models/language_model.cc | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/src/models/language_model.cc b/src/models/language_model.cc index cb5718dc..3824758c 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -53,14 +53,19 @@ namespace ctranslate2 { return run_generation(start_tokens, options); } - static void check_input_location(const std::string& input_name, - const StorageView& input, - const models::Model& model) { - if (input.device() != model.device() || input.device_index() != model.device_index()) - throw std::invalid_argument("Input " + input_name - + " (" + device_to_str(input.device(), input.device_index()) - + ") must be on the same device as the model (" - + device_to_str(model.device(), model.device_index()) + ")"); + StorageView + SequenceGeneratorReplica::forward(const std::vector<std::vector<std::string>>& tokens, + const bool return_log_probs) { + const auto& vocabulary = _model->get_vocabulary(); + return forward(vocabulary.to_ids(tokens), return_log_probs); + } + + StorageView + SequenceGeneratorReplica::forward(const std::vector<std::vector<size_t>>& ids, + const bool return_log_probs) { + StorageView lengths; + StorageView input_ids = layers::make_sequence_inputs(ids, Device::CPU, 1, &lengths); + return forward(input_ids, lengths, return_log_probs); } StorageView @@ -69,12 +74,15 @@ namespace ctranslate2 { const bool return_log_probs) { PROFILE("SequenceGeneratorReplica::forward"); const auto& model = *this->model(); + const auto device = model.device(); const auto scoped_device_setter = model.get_scoped_device_setter(); - check_input_location("ids", ids, model); - check_input_location("lengths", lengths, model); + StorageView output; + if (ids.device() != device) + output = forward(ids.to(device), lengths.to(device)); + else + output = forward(ids, lengths); - StorageView output = forward(ids, lengths); if (return_log_probs) ops::LogSoftMax()(output); |