Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/OpenNMT.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/onmt
diff options
context:
space:
mode:
authorGuillaume Klein <guillaume.klein@systrangroup.com>2018-07-18 15:23:39 +0300
committerGuillaume Klein <guillaume.klein@systrangroup.com>2018-07-18 15:23:39 +0300
commit12e4190823e349cbe6202351b77a943ff8bcd102 (patch)
treead0b799ac509f76ab8458d2196279d467eb9984b /onmt
parentb16f73fbf8f4c0edd59247cc04f9c999cd3f9dfa (diff)
Fix condition to detect LM training
Fixes #552
Diffstat (limited to 'onmt')
-rw-r--r--onmt/train/Trainer.lua4
1 files changed, 2 insertions, 2 deletions
diff --git a/onmt/train/Trainer.lua b/onmt/train/Trainer.lua
index 94d0f0cc..ab764eeb 100644
--- a/onmt/train/Trainer.lua
+++ b/onmt/train/Trainer.lua
@@ -156,7 +156,7 @@ function Trainer:__init(args, model, dicts, trainDataset)
for fi = 1, #firstBatch.sourceInputFeatures do
sfeat[fi] = torch.LongTensor(src_sentmax):fill(onmt.Constants.UNK)
end
- if firstBatch.tgt ~= nil then
+ if firstBatch.targetInput ~= nil then
tgt = {}
tgtFeats = {}
tfeat = tds.Vec(#firstBatch.targetInputFeatures)
@@ -170,7 +170,7 @@ function Trainer:__init(args, model, dicts, trainDataset)
if #firstBatch.sourceInputFeatures > 0 then
table.insert(srcFeats, sfeat)
end
- if firstBatch.tgt ~= nil then
+ if firstBatch.targetInput ~= nil then
table.insert(tgt, torch.LongTensor(trainDataset.maxTargetLength):fill(onmt.Constants.UNK))
if #firstBatch.targetInputFeatures > 0 then
table.insert(tgtFeats, tfeat)