github.com/OpenNMT/OpenNMT.git
author    Vincent Nguyen <vince62s@yahoo.com>   2018-03-05 23:25:18 +0300
committer Jean Senellart <jean@senellart.com>   2018-03-05 23:25:18 +0300
commit    560a5752f3948818f88c19f0265b3e4fc32969f7 (patch)
tree      14dd8fc28f8120396ec423eed22072e88d03ff35 /onmt
parent    5295476b19bd0f8c3ae81b9504770b8b7708763e (diff)
WIP introduce token batch size limit (#506)
* introduce token batch size limit
* add memory preallocation so that varying batch sizes do not cause memory fragmentation
* update changelog
* extend to the sampled dataset
* further optimization by considering only the maximum observed source and target sentence lengths
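
The batching change in a nutshell: while grouping sentences into batches, the patch keeps a running count of source tokens and closes the current batch as soon as the next sentence would push the count past the cap, so either -max_batch_size (sentences) or -max_tokens (source tokens) ends a batch, whichever is hit first. A minimal standalone sketch of that splitting logic (illustrative only; splitIntoBatches is a hypothetical helper, not code from this commit):

-- Split an array of sentence lengths into batch ranges, closing a batch
-- whenever the sentence cap is reached or the token cap would be exceeded.
local function splitIntoBatches(lengths, maxBatchSize, maxTokens)
  local ranges = {}
  local first, size, tokens = 1, 0, 0
  for i = 1, #lengths do
    if i > 1 and (size == maxBatchSize or tokens + lengths[i] > maxTokens) then
      table.insert(ranges, { ["begin"] = first, ["end"] = i - 1 })
      first, size, tokens = i, 0, 0
    end
    size = size + 1
    tokens = tokens + lengths[i]
  end
  table.insert(ranges, { ["begin"] = first, ["end"] = #lengths })
  return ranges
end

-- e.g. splitIntoBatches({12, 14, 30, 9, 9}, 2, 40)
-- -> { {begin=1, end=2}, {begin=3, end=4}, {begin=5, end=5} }
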
Diffstat (limited to 'onmt')
 onmt/data/Dataset.lua        |  7
 onmt/data/DynamicDataset.lua |  7
 onmt/data/SampledDataset.lua |  7
 onmt/train/Optim.lua         |  9
 onmt/train/Trainer.lua       | 45
5 files changed, 63 insertions(+), 12 deletions(-)
diff --git a/onmt/data/Dataset.lua b/onmt/data/Dataset.lua
index 49f9cdc3..846dabd5 100644
--- a/onmt/data/Dataset.lua
+++ b/onmt/data/Dataset.lua
@@ -18,7 +18,7 @@ end
--[[ Setup up the training data to respect `maxBatchSize`.
If uneven_batches - then build up batches with different lengths ]]
-function Dataset:setBatchSize(maxBatchSize, uneven_batches)
+function Dataset:setBatchSize(maxBatchSize, maxTokens, uneven_batches)
self.batchRange = {}
self.maxSourceLength = 0
@@ -32,16 +32,19 @@ function Dataset:setBatchSize(maxBatchSize, uneven_batches)
local batchSize = 1
local maxSourceLength = 0
local targetLength = 0
+ local TokensInBatch = 0
for i = 1, #self.src do
-- Set up the offsets to make same source size batches of the
-- correct size.
local sourceLength = self.src[i]:size(1)
- if batchSize == maxBatchSize or i == 1 or
+ TokensInBatch = TokensInBatch + sourceLength
+ if TokensInBatch > maxTokens or i==1 or batchSize == maxBatchSize or
(not(uneven_batches) and self.src[i]:size(1) ~= maxSourceLength) then
if i > 1 then
batchesCapacity = batchesCapacity + batchSize * maxSourceLength
table.insert(self.batchRange, { ["begin"] = offset, ["end"] = i - 1 })
+ TokensInBatch = sourceLength
end
offset = i
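
One subtlety in the hunk above: when the cap overflows at sentence i, the batch is closed at i - 1 and TokensInBatch is reset to sourceLength rather than zero, because the overflowing sentence itself opens the next batch. A toy trace of that reset (assumed lengths, not data from the commit):

local lengths, maxTokens = {4, 5, 6}, 10
local tokensInBatch = 0
for i, len in ipairs(lengths) do
  tokensInBatch = tokensInBatch + len
  if tokensInBatch > maxTokens then
    print(('batch closed at sentence %d'):format(i - 1))  -- prints "batch closed at sentence 2"
    tokensInBatch = len  -- the overflowing sentence seeds the next batch
  end
end
print(tokensInBatch)  -- 6: sentence 3 sits alone in the new batch
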
diff --git a/onmt/data/DynamicDataset.lua b/onmt/data/DynamicDataset.lua
index cb0ef8f9..69dcd7df 100644
--- a/onmt/data/DynamicDataset.lua
+++ b/onmt/data/DynamicDataset.lua
@@ -8,7 +8,7 @@ function DynamicDataset:__init(opt, ddr)
end
--[[ define batch size ]]
-function DynamicDataset:setBatchSize(maxBatchSize, uneven_batches)
+function DynamicDataset:setBatchSize(maxBatchSize, maxTokens, uneven_batches)
-- time to build first sample
local data = self.ddr.preprocessor:makeData('train', self.ddr.dicts)
self.first = true
@@ -18,8 +18,9 @@ function DynamicDataset:setBatchSize(maxBatchSize, uneven_batches)
self.dataset = onmt.data.Dataset.new(data.src, data.tgt)
end
self.maxBatchSize = maxBatchSize
+ self.maxTokens = maxTokens
self.uneven_batches = uneven_batches
- local nTrainBatch, batchUsage = self.dataset:setBatchSize(maxBatchSize, uneven_batches)
+ local nTrainBatch, batchUsage = self.dataset:setBatchSize(maxBatchSize, maxTokens, uneven_batches)
self.src = self.dataset.src
self.tgt = self.dataset.tgt
self.maxSourceLength = self.dataset.maxSourceLength
@@ -42,7 +43,7 @@ function DynamicDataset:sample()
self.src = self.dataset.src
self.tgt = self.dataset.tgt
self:sampleVocabInit(self.opt, self.src, self.tgt)
- local nTrainBatch, _ = self.dataset:setBatchSize(self.maxBatchSize, self.uneven_batches)
+ local nTrainBatch, _ = self.dataset:setBatchSize(self.maxBatchSize, self.maxTokens, self.uneven_batches)
self.maxSourceLength = self.dataset.maxSourceLength
self.maxTargetLength = self.dataset.maxTargetLength
_G.logger:info('Sampling completed - %d sentences, %d mini-batch', #self.src, nTrainBatch)
diff --git a/onmt/data/SampledDataset.lua b/onmt/data/SampledDataset.lua
index 6aac90bb..cdfc3dd7 100644
--- a/onmt/data/SampledDataset.lua
+++ b/onmt/data/SampledDataset.lua
@@ -74,6 +74,7 @@ function SampledDataset:__init(opt, srcData, tgtData)
self.samplingProb = torch.ones(#self.src)
end
+ self.maxTokens = opt.max_tokens
self.sampled = nil
self.sampledCnt = torch.zeros(#self.src)
end
@@ -212,10 +213,13 @@ function SampledDataset:sample(logLevel)
local sampleCntBegin = 1
local batchSize = 1
local maxSourceLength = -1
+ local TokensInBatch = 0
+
for i = 1, #self.src do
for j = 1, self.sampledCnt[i] do
local sourceLength = self.src[i]:size(1)
- if batchSize == self.maxBatchSize or offset == 1 or
+ TokensInBatch = TokensInBatch + sourceLength
+ if TokensInBatch > self.maxTokens or batchSize == self.maxBatchSize or offset == 1 or
(not(self.uneven_batches) and self.src[i]:size(1) ~= maxSourceLength) then
if offset > 0 then
batchesCapacity = batchesCapacity + batchSize * maxSourceLength
@@ -228,6 +232,7 @@ function SampledDataset:sample(logLevel)
["sampleCntEnd"] = sampleCntEnd
})
sampleCntBegin = (j == 1) and 1 or j
+ TokensInBatch = sourceLength
end
offset = i
batchSize = 1
diff --git a/onmt/train/Optim.lua b/onmt/train/Optim.lua
index b0faae84..96109d7d 100644
--- a/onmt/train/Optim.lua
+++ b/onmt/train/Optim.lua
@@ -2,13 +2,20 @@ local Optim = torch.class('Optim')
local options = {
{
- '-max_batch_size', 64,
+ '-max_batch_size', 160,
[[Maximum batch size.]],
{
valid = onmt.utils.ExtendedCmdLine.isUInt()
}
},
{
+ '-max_tokens', 1800,
+ [[Maximum tokens in a batch.]],
+ {
+ valid = onmt.utils.ExtendedCmdLine.isUInt()
+ }
+ },
+ {
'-uneven_batches', false,
[[If set, batches are filled up to `-max_batch_size` even if the source lengths are different.
Slower but needed for some tasks.]]
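
Both knobs are validated as unsigned integers: -max_batch_size (default raised from 64 to 160) caps sentences per batch, while the new -max_tokens (default 1800) caps source tokens per batch. For a batch of same-length sentences, the resulting effective batch size can be estimated as below (a hedged sketch; effectiveBatchSize is a hypothetical helper, not part of the patch):

-- Effective batch size for same-length sentences: the smaller of the
-- sentence cap and however many sentences fit under the token cap.
local function effectiveBatchSize(sentLen, maxBatchSize, maxTokens)
  return math.min(maxBatchSize, math.floor(maxTokens / sentLen))
end

print(effectiveBatchSize(10, 160, 1800))  -- 160: short sentences hit the sentence cap
print(effectiveBatchSize(40, 160, 1800))  -- 45: long sentences hit the token cap
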
diff --git a/onmt/train/Trainer.lua b/onmt/train/Trainer.lua
index 94814ad0..72beaa83 100644
--- a/onmt/train/Trainer.lua
+++ b/onmt/train/Trainer.lua
@@ -1,3 +1,5 @@
+tds = require 'tds'
+
local Trainer = torch.class('Trainer')
local options = {
@@ -94,7 +96,7 @@ function Trainer.declareOpts(cmd)
onmt.translate.Translator.declareOpts(cmd)
end
-function Trainer:__init(args, model, dicts, firstBatch)
+function Trainer:__init(args, model, dicts, trainDataset)
self.args = onmt.utils.ExtendedCmdLine.getModuleOpts(args, options)
self.args.profiler = args.profiler
self.args.disable_mem_optimization = args.disable_mem_optimization
@@ -134,11 +136,44 @@ function Trainer:__init(args, model, dicts, firstBatch)
-- If enabled, share internal buffers to optimize for memory.
if not self.args.disable_mem_optimization then
- if not firstBatch then
- _G.logger:error('A first batch is needed to optimize the computation graph for memory')
- else
- onmt.utils.Memory.optimize(model, onmt.utils.Cuda.convert(firstBatch))
+ local firstBatch = trainDataset:getBatch(1)
+ onmt.utils.Memory.optimize(model, onmt.utils.Cuda.convert(firstBatch))
+
+ _G.logger:info('Preallocating memory')
+ -- preallocate memory in encoder and decoder with highest batch size, and longest sentence
+ -- for target, we take args.tgt_seq_length as the longest possible (worse case). This could be refined
+ -- since in batch building, we can calculate tgt_max_tokens
+ local src_sentmax = math.min(trainDataset.maxSourceLength, math.ceil(args.max_tokens/args.max_batch_size))
+
+ local src = {}
+ local srcFeats = {}
+ local tgt = {}
+ local tgtFeats = {}
+ local sfeat = tds.Vec(#firstBatch.sourceInputFeatures)
+ for fi = 1, #firstBatch.sourceInputFeatures do
+ sfeat[fi] = torch.LongTensor(src_sentmax):fill(onmt.Constants.UNK)
+ end
+ local tfeat = tds.Vec(#firstBatch.targetInputFeatures)
+ for fi = 1, #firstBatch.targetInputFeatures do
+ tfeat[fi] = torch.LongTensor(trainDataset.maxTargetLength):fill(onmt.Constants.UNK)
end
+ while #src < args.max_batch_size do
+ table.insert(src, torch.LongTensor(src_sentmax):fill(onmt.Constants.UNK))
+ if #firstBatch.sourceInputFeatures > 0 then
+ table.insert(srcFeats, sfeat)
+ end
+ table.insert(tgt, torch.LongTensor(trainDataset.maxTargetLength):fill(onmt.Constants.UNK))
+ if #firstBatch.targetInputFeatures > 0 then
+ table.insert(tgtFeats, tfeat)
+ end
+ end
+
+ -- memory should be stable now
+ local b = onmt.data.Batch.new(src, srcFeats, tgt, tgtFeats)
+ onmt.utils.Cuda.convert(b)
+
+ model:trainNetwork(b, true)
+
end
-- Add profiling hooks.
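
The preallocation added above builds a worst-case dummy batch (max_batch_size sentences, each with src_sentmax source tokens and maxTargetLength target tokens, filled with UNK) and runs one training step on it, so that subsequent, smaller batches reuse buffers that are already allocated instead of fragmenting memory as batch shapes vary. With the new defaults, the source-side bound works out as follows (illustrative numbers; the maxSourceLength of 50 is assumed):

local max_tokens, max_batch_size = 1800, 160
local maxSourceLength = 50  -- assumed corpus maximum, for illustration

-- ceil(max_tokens / max_batch_size) is an upper bound on the sentence length
-- of any batch that actually reaches max_batch_size sentences: anything
-- longer trips the token cap first and yields a smaller batch.
local src_sentmax = math.min(maxSourceLength, math.ceil(max_tokens / max_batch_size))
print(src_sentmax)  -- 12
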