Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'include/ctranslate2/models/language_model.h')
-rw-r--r--include/ctranslate2/models/language_model.h10
1 files changed, 9 insertions, 1 deletions
diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h
index a1dd55b2..e0bc0a8e 100644
--- a/include/ctranslate2/models/language_model.h
+++ b/include/ctranslate2/models/language_model.h
@@ -25,8 +25,9 @@ namespace ctranslate2 {
// Base class for generative language models.
class SequenceGeneratorReplica : public ModelReplica {
public:
- SequenceGeneratorReplica(const std::shared_ptr<const Model>& model)
+ SequenceGeneratorReplica(const std::shared_ptr<const LanguageModel>& model)
: ModelReplica(model)
+ , _model(model)
{
}
@@ -42,6 +43,10 @@ namespace ctranslate2 {
generate(const std::vector<std::vector<std::string>>& start_tokens,
const GenerationOptions& options = GenerationOptions());
+ StorageView forward(const std::vector<std::vector<std::string>>& tokens,
+ const bool return_log_probs);
+ StorageView forward(const std::vector<std::vector<size_t>>& ids,
+ const bool return_log_probs);
StorageView forward(const StorageView& ids,
const StorageView& lengths,
const bool return_log_probs);
@@ -65,6 +70,9 @@ namespace ctranslate2 {
const GenerationOptions& options) = 0;
virtual StorageView forward(const StorageView& ids, const StorageView& lengths) = 0;
+
+ private:
+ const std::shared_ptr<const LanguageModel> _model;
};