diff options
Diffstat (limited to 'include/ctranslate2/models/language_model.h')
-rw-r--r-- | include/ctranslate2/models/language_model.h | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index a1dd55b2..e0bc0a8e 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -25,8 +25,9 @@ namespace ctranslate2 { // Base class for generative language models. class SequenceGeneratorReplica : public ModelReplica { public: - SequenceGeneratorReplica(const std::shared_ptr<const Model>& model) + SequenceGeneratorReplica(const std::shared_ptr<const LanguageModel>& model) : ModelReplica(model) + , _model(model) { } @@ -42,6 +43,10 @@ namespace ctranslate2 { generate(const std::vector<std::vector<std::string>>& start_tokens, const GenerationOptions& options = GenerationOptions()); + StorageView forward(const std::vector<std::vector<std::string>>& tokens, + const bool return_log_probs); + StorageView forward(const std::vector<std::vector<size_t>>& ids, + const bool return_log_probs); StorageView forward(const StorageView& ids, const StorageView& lengths, const bool return_log_probs); @@ -65,6 +70,9 @@ namespace ctranslate2 { const GenerationOptions& options) = 0; virtual StorageView forward(const StorageView& ids, const StorageView& lengths) = 0; + + private: + const std::shared_ptr<const LanguageModel> _model; }; |