diff options
author | Guillaume Klein <guillaume.klein@systrangroup.com> | 2018-07-18 15:23:39 +0300 |
---|---|---|
committer | Guillaume Klein <guillaume.klein@systrangroup.com> | 2018-07-18 15:23:39 +0300 |
commit | 12e4190823e349cbe6202351b77a943ff8bcd102 (patch) | |
tree | ad0b799ac509f76ab8458d2196279d467eb9984b /onmt | |
parent | b16f73fbf8f4c0edd59247cc04f9c999cd3f9dfa (diff) |
Fix condition to detect LM training
Fixes #552
Diffstat (limited to 'onmt')
-rw-r--r-- | onmt/train/Trainer.lua | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/onmt/train/Trainer.lua b/onmt/train/Trainer.lua index 94d0f0cc..ab764eeb 100644 --- a/onmt/train/Trainer.lua +++ b/onmt/train/Trainer.lua @@ -156,7 +156,7 @@ function Trainer:__init(args, model, dicts, trainDataset) for fi = 1, #firstBatch.sourceInputFeatures do sfeat[fi] = torch.LongTensor(src_sentmax):fill(onmt.Constants.UNK) end - if firstBatch.tgt ~= nil then + if firstBatch.targetInput ~= nil then tgt = {} tgtFeats = {} tfeat = tds.Vec(#firstBatch.targetInputFeatures) @@ -170,7 +170,7 @@ function Trainer:__init(args, model, dicts, trainDataset) if #firstBatch.sourceInputFeatures > 0 then table.insert(srcFeats, sfeat) end - if firstBatch.tgt ~= nil then + if firstBatch.targetInput ~= nil then table.insert(tgt, torch.LongTensor(trainDataset.maxTargetLength):fill(onmt.Constants.UNK)) if #firstBatch.targetInputFeatures > 0 then table.insert(tgtFeats, tfeat) |