author     Vincent Nguyen <vince62s@yahoo.com>         2018-03-05 23:25:18 +0300
committer  Jean Senellart <jean@senellart.com>         2018-03-05 23:25:18 +0300
commit     560a5752f3948818f88c19f0265b3e4fc32969f7 (patch)
tree       14dd8fc28f8120396ec423eed22072e88d03ff35 /onmt
parent     5295476b19bd0f8c3ae81b9504770b8b7708763e (diff)
WIP introduce token batch size limit (#506)
* introduce a token batch size limit (the idea is sketched after this list)
* add memory preallocation so that varying batch sizes do not cause memory fragmentation
* update changelog
* extension to sampled dataset
* further optimization by considering only the maximum observed source and target sentence lengths
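
The token-limit idea, separated from the OpenNMT data pipeline, is roughly the following. This is a minimal standalone Lua sketch, not the committed code: `sentences` is a hypothetical table of tokenized sentences, and the real implementation is `Dataset:setBatchSize` in the diff below.

    -- Split `sentences` (a table of token tables) into batch ranges so that
    -- no batch exceeds `maxBatchSize` sentences or `maxTokens` source tokens.
    local function makeBatchRanges(sentences, maxBatchSize, maxTokens)
      local ranges = {}
      local begin = 1
      local batchSize = 0
      local tokensInBatch = 0
      for i = 1, #sentences do
        local len = #sentences[i]
        -- Close the current batch if adding sentence i would break either limit.
        if batchSize > 0 and
           (batchSize + 1 > maxBatchSize or tokensInBatch + len > maxTokens) then
          table.insert(ranges, { ["begin"] = begin, ["end"] = i - 1 })
          begin = i
          batchSize = 0
          tokensInBatch = 0
        end
        batchSize = batchSize + 1
        tokensInBatch = tokensInBatch + len
      end
      if batchSize > 0 then
        table.insert(ranges, { ["begin"] = begin, ["end"] = #sentences })
      end
      return ranges
    end

Unlike a pure sentence-count limit, this keeps the total token count per batch (and hence the padded tensor size and memory use) roughly constant across batches of short and long sentences.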
Diffstat (limited to 'onmt')
-rw-r--r--  onmt/data/Dataset.lua          7
-rw-r--r--  onmt/data/DynamicDataset.lua   7
-rw-r--r--  onmt/data/SampledDataset.lua   7
-rw-r--r--  onmt/train/Optim.lua           9
-rw-r--r--  onmt/train/Trainer.lua         45
5 files changed, 63 insertions, 12 deletions
diff --git a/onmt/data/Dataset.lua b/onmt/data/Dataset.lua
index 49f9cdc3..846dabd5 100644
--- a/onmt/data/Dataset.lua
+++ b/onmt/data/Dataset.lua
@@ -18,7 +18,7 @@ end
 
 --[[ Setup up the training data to respect `maxBatchSize`.
   If uneven_batches - then build up batches with different lengths ]]
-function Dataset:setBatchSize(maxBatchSize, uneven_batches)
+function Dataset:setBatchSize(maxBatchSize, maxTokens, uneven_batches)
 
   self.batchRange = {}
   self.maxSourceLength = 0
@@ -32,16 +32,19 @@ function Dataset:setBatchSize(maxBatchSize, uneven_batches)
   local batchSize = 1
   local maxSourceLength = 0
   local targetLength = 0
+  local TokensInBatch = 0
 
   for i = 1, #self.src do
     -- Set up the offsets to make same source size batches of the
     -- correct size.
     local sourceLength = self.src[i]:size(1)
-    if batchSize == maxBatchSize or i == 1 or
+    TokensInBatch = TokensInBatch + sourceLength
+    if TokensInBatch > maxTokens or i==1 or batchSize == maxBatchSize or
        (not(uneven_batches) and self.src[i]:size(1) ~= maxSourceLength) then
       if i > 1 then
         batchesCapacity = batchesCapacity + batchSize * maxSourceLength
         table.insert(self.batchRange, { ["begin"] = offset, ["end"] = i - 1 })
+        TokensInBatch = sourceLength
       end
 
       offset = i
diff --git a/onmt/data/DynamicDataset.lua b/onmt/data/DynamicDataset.lua
index cb0ef8f9..69dcd7df 100644
--- a/onmt/data/DynamicDataset.lua
+++ b/onmt/data/DynamicDataset.lua
@@ -8,7 +8,7 @@ function DynamicDataset:__init(opt, ddr)
 end
 
 --[[ define batch size ]]
-function DynamicDataset:setBatchSize(maxBatchSize, uneven_batches)
+function DynamicDataset:setBatchSize(maxBatchSize, maxTokens, uneven_batches)
   -- time to build first sample
   local data = self.ddr.preprocessor:makeData('train', self.ddr.dicts)
   self.first = true
@@ -18,8 +18,9 @@ function DynamicDataset:setBatchSize(maxBatchSize, uneven_batches)
     self.dataset = onmt.data.Dataset.new(data.src, data.tgt)
   end
   self.maxBatchSize = maxBatchSize
+  self.maxTokens = maxTokens
   self.uneven_batches = uneven_batches
-  local nTrainBatch, batchUsage = self.dataset:setBatchSize(maxBatchSize, uneven_batches)
+  local nTrainBatch, batchUsage = self.dataset:setBatchSize(maxBatchSize, maxTokens, uneven_batches)
   self.src = self.dataset.src
   self.tgt = self.dataset.tgt
   self.maxSourceLength = self.dataset.maxSourceLength
@@ -42,7 +43,7 @@ function DynamicDataset:sample()
   self.src = self.dataset.src
   self.tgt = self.dataset.tgt
   self:sampleVocabInit(self.opt, self.src, self.tgt)
-  local nTrainBatch, _ = self.dataset:setBatchSize(self.maxBatchSize, self.uneven_batches)
+  local nTrainBatch, _ = self.dataset:setBatchSize(self.maxBatchSize, self.maxTokens, self.uneven_batches)
   self.maxSourceLength = self.dataset.maxSourceLength
   self.maxTargetLength = self.dataset.maxTargetLength
   _G.logger:info('Sampling completed - %d sentences, %d mini-batch', #self.src, nTrainBatch)
diff --git a/onmt/data/SampledDataset.lua b/onmt/data/SampledDataset.lua
index 6aac90bb..cdfc3dd7 100644
--- a/onmt/data/SampledDataset.lua
+++ b/onmt/data/SampledDataset.lua
@@ -74,6 +74,7 @@ function SampledDataset:__init(opt, srcData, tgtData)
     self.samplingProb = torch.ones(#self.src)
   end
 
+  self.maxTokens = opt.max_tokens
   self.sampled = nil
   self.sampledCnt = torch.zeros(#self.src)
 end
@@ -212,10 +213,13 @@ function SampledDataset:sample(logLevel)
   local sampleCntBegin = 1
   local batchSize = 1
   local maxSourceLength = -1
+  local TokensInBatch = 0
+
   for i = 1, #self.src do
     for j = 1, self.sampledCnt[i] do
       local sourceLength = self.src[i]:size(1)
-      if batchSize == self.maxBatchSize or offset == 1 or
+      TokensInBatch = TokensInBatch + sourceLength
+      if TokensInBatch > self.maxTokens or batchSize == self.maxBatchSize or offset == 1 or
          (not(self.uneven_batches) and self.src[i]:size(1) ~= maxSourceLength) then
         if offset > 0 then
           batchesCapacity = batchesCapacity + batchSize * maxSourceLength
@@ -228,6 +232,7 @@ function SampledDataset:sample(logLevel)
             ["sampleCntEnd"] = sampleCntEnd })
           sampleCntBegin = (j == 1) and 1 or j
+          TokensInBatch = sourceLength
         end
         offset = i
         batchSize = 1
diff --git a/onmt/train/Optim.lua b/onmt/train/Optim.lua
index b0faae84..96109d7d 100644
--- a/onmt/train/Optim.lua
+++ b/onmt/train/Optim.lua
@@ -2,13 +2,20 @@ local Optim = torch.class('Optim')
 
 local options = {
   {
-    '-max_batch_size', 64,
+    '-max_batch_size', 160,
     [[Maximum batch size.]],
     {
       valid = onmt.utils.ExtendedCmdLine.isUInt()
     }
   },
   {
+    '-max_tokens', 1800,
+    [[Maximum tokens in a batch.]],
+    {
+      valid = onmt.utils.ExtendedCmdLine.isUInt()
+    }
+  },
+  {
     '-uneven_batches', false,
     [[If set, batches are filled up to `-max_batch_size` even if the source lengths are different.
       Slower but needed for some tasks.]]
diff --git a/onmt/train/Trainer.lua b/onmt/train/Trainer.lua
index 94814ad0..72beaa83 100644
--- a/onmt/train/Trainer.lua
+++ b/onmt/train/Trainer.lua
@@ -1,3 +1,5 @@
+tds = require 'tds'
+
 local Trainer = torch.class('Trainer')
 
 local options = {
@@ -94,7 +96,7 @@ function Trainer.declareOpts(cmd)
   onmt.translate.Translator.declareOpts(cmd)
 end
 
-function Trainer:__init(args, model, dicts, firstBatch)
+function Trainer:__init(args, model, dicts, trainDataset)
   self.args = onmt.utils.ExtendedCmdLine.getModuleOpts(args, options)
   self.args.profiler = args.profiler
   self.args.disable_mem_optimization = args.disable_mem_optimization
@@ -134,11 +136,44 @@ function Trainer:__init(args, model, dicts, firstBatch)
 
   -- If enabled, share internal buffers to optimize for memory.
   if not self.args.disable_mem_optimization then
-    if not firstBatch then
-      _G.logger:error('A first batch is needed to optimize the computation graph for memory')
-    else
-      onmt.utils.Memory.optimize(model, onmt.utils.Cuda.convert(firstBatch))
+    local firstBatch = trainDataset:getBatch(1)
+    onmt.utils.Memory.optimize(model, onmt.utils.Cuda.convert(firstBatch))
+
+    _G.logger:info('Preallocating memory')
+    -- preallocate memory in encoder and decoder with highest batch size, and longest sentence
+    -- for target, we take args.tgt_seq_length as the longest possible (worse case). This could be refined
+    -- since in batch building, we can calculate tgt_max_tokens
+    local src_sentmax = math.min(trainDataset.maxSourceLength, math.ceil(args.max_tokens/args.max_batch_size))
+
+    local src = {}
+    local srcFeats = {}
+    local tgt = {}
+    local tgtFeats = {}
+    local sfeat = tds.Vec(#firstBatch.sourceInputFeatures)
+    for fi = 1, #firstBatch.sourceInputFeatures do
+      sfeat[fi] = torch.LongTensor(src_sentmax):fill(onmt.Constants.UNK)
+    end
+    local tfeat = tds.Vec(#firstBatch.targetInputFeatures)
+    for fi = 1, #firstBatch.targetInputFeatures do
+      tfeat[fi] = torch.LongTensor(trainDataset.maxTargetLength):fill(onmt.Constants.UNK)
     end
+    while #src < args.max_batch_size do
+      table.insert(src, torch.LongTensor(src_sentmax):fill(onmt.Constants.UNK))
+      if #firstBatch.sourceInputFeatures > 0 then
+        table.insert(srcFeats, sfeat)
+      end
+      table.insert(tgt, torch.LongTensor(trainDataset.maxTargetLength):fill(onmt.Constants.UNK))
+      if #firstBatch.targetInputFeatures > 0 then
+        table.insert(tgtFeats, tfeat)
+      end
+    end
+
+    -- memory should be stable now
+    local b = onmt.data.Batch.new(src, srcFeats, tgt, tgtFeats)
+    onmt.utils.Cuda.convert(b)
+
+    model:trainNetwork(b, true)
+
   end
 
   -- Add profiling hooks.
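
With these options in place, the two limits can be combined on the command line. A hypothetical invocation (data and model paths are placeholders; `-max_batch_size` and `-max_tokens` are the options declared in onmt/train/Optim.lua above):

    th train.lua -data data/demo-train.t7 -save_model demo-model -max_batch_size 160 -max_tokens 1800

A batch is then closed as soon as adding another sentence would push it past 160 sentences or 1800 source tokens, whichever limit is hit first. The preallocation pass in Trainer.lua complements this: it builds one worst-case dummy batch (max_batch_size sentences, each as long as the token budget allows) and runs a single training step on it, so that later batches of varying shapes reuse already-allocated buffers instead of fragmenting memory.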