github.com/OpenNMT/OpenNMT.git
author     Jungi Kim <jungi.kim@gmail.com>                           2017-11-23 12:40:21 +0300
committer  Guillaume Klein <guillaumekln@users.noreply.github.com>   2017-11-23 12:40:21 +0300
commit     b0a05bbdf672870837fb8a5c5020497e2ab68bfa (patch)
tree       16cb0adfaa7342889af3a239dde36d31e0ff4f0c /onmt
parent     f054ac998d7a2873c6a955e6f39afbdf1cb986e3 (diff)
Sentence-level negative log-likelihood criterion for SeqTagger (#438)
* Initial commit of Sentence-level Log-likelihood Criterion
* Sentence-level Log-likelihood Criterion: minor optimization
* Sentence-level Log-likelihood Criterion: memory optimization (minimize new allocation)
* Sentence-level Log-likelihood Criterion: release() should remove derivatives as well
* Sentence-level Log-likelihood Criterion: initial implementation of batch-vectorized Viterbi search
* Sentence-level Log-likelihood Criterion: memory optimization for vectorized Viterbi search
* Sentence-level Log-likelihood Criterion: vectorized forward
* Sentence-level Log-likelihood Criterion: vectorized backward
* Sentence-level Log-likelihood Criterion: remove unused code
* Sentence-level Log-likelihood Criterion: clear luacheck warnings
* Update changelog and options listing
Diffstat (limited to 'onmt')
-rw-r--r--  onmt/Factory.lua                       |   4
-rw-r--r--  onmt/SeqTagger.lua                     | 181
-rw-r--r--  onmt/modules/SentenceNLLCriterion.lua  | 644
-rw-r--r--  onmt/modules/init.lua                  |   2
-rw-r--r--  onmt/tagger/Tagger.lua                 | 112
5 files changed, 889 insertions(+), 54 deletions(-)
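
Before the per-file diffs, a note on how the new behaviour is switched on: SeqTagger gains a -loglikelihood option, and its two values select the criterion. A hedged sketch of the relevant arguments as SeqTagger:__init() below consumes them (the surrounding training configuration is assumed and is not part of this commit):

  local args = {
    loglikelihood = 'sentence', -- 'word' (the default) keeps the per-token ParallelClassNLLCriterion;
                                -- 'sentence' selects the new SentenceNLLCriterion
    max_grad_norm = 5,          -- also read by SentenceNLLCriterion to clip its internal gradients
    -- ...the remaining SeqTagger/encoder options are unchanged by this commit
  }

With 'sentence', the criterion carries learnable transition parameters (A0, A), is registered under self.models so it is saved and loaded with the model, and tagging decodes with its Viterbi search instead of a per-position argmax.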
diff --git a/onmt/Factory.lua b/onmt/Factory.lua
index 6350a087..b197734b 100644
--- a/onmt/Factory.lua
+++ b/onmt/Factory.lua
@@ -283,4 +283,8 @@ function Factory.buildAttention(args)
end
end
+function Factory.loadSentenceNLLCriterion(pretrained)
+ return onmt.SentenceNLLCriterion.load(pretrained)
+end
+
return Factory
diff --git a/onmt/SeqTagger.lua b/onmt/SeqTagger.lua
index 7728adb7..9b41ec69 100644
--- a/onmt/SeqTagger.lua
+++ b/onmt/SeqTagger.lua
@@ -49,6 +49,16 @@ local options = {
valid = onmt.utils.ExtendedCmdLine.isUInt(),
structural = 0
}
+ },
+ {
+ '-loglikelihood', 'word',
+ [[Specifies the type of log-likelihood used by the tagger model;
+ 'word' indicates tags are predicted at the word level, and
+ 'sentence' indicates the tagging process is treated as a Markov chain]]
+ {
+ enum = {'word', 'sentence'},
+ structural = 1
+ }
}
}
@@ -70,7 +80,18 @@ function SeqTagger:__init(args, dicts)
self.models.encoder = onmt.Factory.buildWordEncoder(args, dicts.src)
self.models.generator = onmt.Factory.buildGenerator(args, dicts.tgt)
- self.criterion = onmt.ParallelClassNLLCriterion(onmt.Factory.getOutputSizes(dicts.tgt))
+
+ onmt.utils.Error.assert(args.loglikelihood == 'word' or args.loglikelihood == 'sentence',
+ 'Invalid loglikelihood type of SeqTagger `%s\'', args.loglikelihood)
+
+ self.loglikelihood = args.loglikelihood
+
+ if self.loglikelihood == 'word' then
+ self.criterion = onmt.ParallelClassNLLCriterion(onmt.Factory.getOutputSizes(dicts.tgt))
+ elseif self.loglikelihood == 'sentence' then
+ self.criterion = onmt.SentenceNLLCriterion(args, onmt.Factory.getOutputSizes(dicts.tgt))
+ self.models.criterion = self.criterion -- the criterion holds learnable parameters, so register it with the models
+ end
end
function SeqTagger.load(args, models, dicts)
@@ -81,7 +102,27 @@ function SeqTagger.load(args, models, dicts)
self.models.encoder = onmt.Factory.loadEncoder(models.encoder)
self.models.generator = onmt.Factory.loadGenerator(models.generator)
- self.criterion = onmt.ParallelClassNLLCriterion(onmt.Factory.getOutputSizes(dicts.tgt))
+
+ onmt.utils.Error.assert(args.loglikelihood == 'word' or args.loglikelihood == 'sentence',
+ 'Invalid loglikelihood type of SeqTagger `%s\'', args.loglikelihood)
+
+ if args.loglikelihood == 'word' then
+ self.criterion = onmt.ParallelClassNLLCriterion(onmt.Factory.getOutputSizes(dicts.tgt))
+ self.loglikelihood = 'word'
+ elseif args.loglikelihood == 'sentence' then
+ if not models.criterion then -- loading a pre-trained word-level model to continue training with the sentence-level criterion
+ _G.logger:info('Creating a new SentenceNLLCriterion')
+ self.criterion = onmt.SentenceNLLCriterion(args, onmt.Factory.getOutputSizes(dicts.tgt))
+ local p, g = self.criterion:getParameters()
+ p:uniform(-args.param_init, args.param_init)
+ g:uniform(-args.param_init, args.param_init)
+ self.criterion:postParametersInitialization()
+ else
+ self.criterion = onmt.Factory.loadSentenceNLLCriterion(models.criterion)
+ end
+ self.models.criterion = self.criterion
+ self.loglikelihood = 'sentence'
+ end
return self
end
@@ -115,15 +156,28 @@ function SeqTagger:forwardComputeLoss(batch)
local loss = 0
- for t = 1, batch.sourceLength do
- local genOutputs = self.models.generator:forward(context:select(2, t))
+ if self.loglikelihood == 'sentence' then
+ local reference = batch.targetOutput:t() -- SeqLen x B -> B x SeqLen
+ local tagsScoreTable = {}
+ for t = 1, batch.sourceLength do
+ local tagsScore = self.models.generator:forward(context:select(2, t)) -- B x TagSize
+ -- tagsScore is a table
+ tagsScore = nn.utils.addSingletonDimension(tagsScore[1], 3):clone() -- B x TagSize x 1
+ table.insert(tagsScoreTable, tagsScore)
+ end
+ local tagsScores = nn.JoinTable(3):forward(tagsScoreTable) -- B x TagSize x SeqLen
+ loss = self.models.criterion:forward(tagsScores, reference)
+ else -- 'word'
+ for t = 1, batch.sourceLength do
+ local genOutputs = self.models.generator:forward(context:select(2, t))
- local output = batch:getTargetOutput(t)
+ local output = batch:getTargetOutput(t)
- -- Same format with and without features.
- if torch.type(output) ~= 'table' then output = { output } end
+ -- Same format with and without features.
+ if torch.type(output) ~= 'table' then output = { output } end
- loss = loss + self.criterion:forward(genOutputs, output)
+ loss = loss + self.criterion:forward(genOutputs, output)
+ end
end
return loss
@@ -136,23 +190,50 @@ function SeqTagger:trainNetwork(batch)
local gradContexts = context:clone():zero()
- -- For each word of the sentence, generate target.
- for t = 1, batch.sourceLength do
- local genOutputs = self.models.generator:forward(context:select(2, t))
+ if self.loglikelihood == 'sentence' then
- local output = batch:getTargetOutput(t)
+ local reference = batch.targetOutput:t() -- SeqLen x B -> B x SeqLen
+ local B = batch.size
+ local T = batch.sourceLength
- -- Same format with and without features.
- if torch.type(output) ~= 'table' then output = { output } end
+ local tagsScoreTable = {}
+ for t = 1, T do
+ local tagsScore = self.models.generator:forward(context:select(2, t)) -- B x TagSize
+ -- tagsScore is a table
+ tagsScore = nn.utils.addSingletonDimension(tagsScore[1], 3):clone() -- B x TagSize x 1
+ table.insert(tagsScoreTable, tagsScore)
+ end
+ local tagsScores = nn.JoinTable(3):forward(tagsScoreTable) -- B x TagSize x SeqLen
- loss = loss + self.criterion:forward(genOutputs, output)
+ loss = loss + self.criterion:forward(tagsScores, reference)
- local genGradOutput = self.criterion:backward(genOutputs, output)
- for j = 1, #genGradOutput do
- genGradOutput[j]:div(batch.totalSize)
+ local gradCriterion = self.models.criterion:backward(tagsScores, reference) -- B x TagSize x SeqLen
+
+ gradCriterion = torch.div(gradCriterion, B)
+ for t = 1, T do
+ gradContexts:select(2,t):copy(self.models.generator:backward(context:select(2, t), {gradCriterion:select(3,t)}))
end
- gradContexts[{{}, t}]:copy(self.models.generator:backward(context:select(2, t), genGradOutput))
+ else -- 'word'
+ -- For each word of the sentence, generate target.
+ for t = 1, batch.sourceLength do
+ local genOutputs = self.models.generator:forward(context:select(2, t))
+
+ local output = batch:getTargetOutput(t)
+
+ -- Same format with and without features.
+ if torch.type(output) ~= 'table' then output = { output } end
+
+ loss = loss + self.criterion:forward(genOutputs, output)
+
+ local genGradOutput = self.criterion:backward(genOutputs, output)
+
+ for j = 1, #genGradOutput do
+ genGradOutput[j]:div(batch.totalSize)
+ end
+
+ gradContexts[{{}, t}]:copy(self.models.generator:backward(context:select(2, t), genGradOutput))
+ end
end
self.models.encoder:backward(batch, nil, gradContexts)
@@ -160,4 +241,66 @@ function SeqTagger:trainNetwork(batch)
return loss
end
+function SeqTagger:tagBatch(batch)
+ local pred = {}
+ local feats = {}
+
+ for _ = 1, batch.size do
+ table.insert(pred, {})
+ table.insert(feats, {})
+ end
+ local _, context = self.models.encoder:forward(batch)
+
+ if self.loglikelihood == 'sentence' then
+
+ local tagsScoreTable = {}
+ for t = 1, batch.sourceLength do
+ local tagsScore = self.models.generator:forward(context:select(2, t)) -- B x TagSize
+ -- tagsScore is a table
+ tagsScore = nn.utils.addSingletonDimension(tagsScore[1], 3):clone() -- B x TagSize x 1
+ table.insert(tagsScoreTable, tagsScore)
+ end
+ local tagsScores = onmt.utils.Cuda.convert(nn.JoinTable(3):forward(tagsScoreTable)) -- B x TagSize x SeqLen
+
+ -- viterbi search
+ local senPreds = self.criterion:viterbiSearch(tagsScores, batch.sourceSize) -- B x SeqLen (type Long)
+
+ for t = 1, batch.sourceLength do
+ for b = 1, batch.size do
+ -- padded in the beginning
+ if t > batch.sourceLength - batch.sourceSize[b] then
+ pred[b][t - batch.sourceLength + batch.sourceSize[b]] = senPreds[b][t]
+ feats[b][t - batch.sourceLength + batch.sourceSize[b]] = {}
+ end
+ end
+ end
+
+ else -- 'word'
+ for t = 1, batch.sourceLength do
+ local out = self.models.generator:forward(context:select(2, t))
+ if type(out[1]) == 'table' then
+ out = out[1]
+ end
+ local _, best = out[1]:max(2)
+ for b = 1, batch.size do
+ if t > batch.sourceLength - batch.sourceSize[b] then
+ pred[b][t - batch.sourceLength + batch.sourceSize[b]] = best[b][1]
+ feats[b][t - batch.sourceLength + batch.sourceSize[b]] = {}
+ end
+ end
+ for j = 2, #out do
+ _, best = out[j]:max(2)
+ for b = 1, batch.size do
+ if t > batch.sourceLength - batch.sourceSize[b] then
+ feats[b][t - batch.sourceLength + batch.sourceSize[b]][j - 1] = best[b][1]
+ end
+ end
+ end
+ end
+
+ end
+
+ return pred, feats
+end
+
return SeqTagger
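
The loss, training, and tagging paths above all hand the criterion the same layout: emission scores of shape B x TagSize x SeqLen (log scale), gold tags of shape B x (SeqLen+1) with a trailing EOS column (this mirrors batch.targetOutput as used in forwardComputeLoss, and follows from how updateOutput() appends an EOS score column), and per-sentence lengths for Viterbi decoding. A minimal standalone sketch of that contract with toy sizes; everything outside the criterion's own API (the float default type, the toy values, the absence of left padding) is an assumption:

  require('onmt.init')                                  -- loads the onmt namespace (torch and nn included)
  torch.setdefaulttensortype('torch.FloatTensor')       -- assumption: match OpenNMT's float tensors

  local B, N, T = 2, 10, 4                              -- toy sizes: 2 sentences, 10 tags, length 4
  local crit = onmt.SentenceNLLCriterion({ max_grad_norm = 5 }, N)
  crit:training()                                       -- allocates the double-precision training caches

  local scores = torch.randn(B, N, T)                   -- B x TagSize x SeqLen emission log-scores
  local gold = torch.LongTensor(B, T + 1)
  gold[{{}, {1, T}}]:random(5, N)                       -- indices 1-4 are reserved (PAD/UNK/BOS/EOS)
  gold[{{}, T + 1}] = onmt.Constants.EOS                -- trailing EOS column, as in batch.targetOutput

  local loss = crit:forward(scores, gold)               -- sentence-level NLL, summed over the batch
  local dScores = crit:backward(scores, gold)           -- gradient w.r.t. the emission scores (B x TagSize x SeqLen)

  -- at tagging time, decode with the criterion's Viterbi search instead of a per-position argmax
  local sourceSizes = torch.LongTensor(B):fill(T)       -- no left padding in this toy case
  local best = crit:viterbiSearch(scores, sourceSizes)  -- B x SeqLen predicted tag indices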
diff --git a/onmt/modules/SentenceNLLCriterion.lua b/onmt/modules/SentenceNLLCriterion.lua
new file mode 100644
index 00000000..5e2c5832
--- /dev/null
+++ b/onmt/modules/SentenceNLLCriterion.lua
@@ -0,0 +1,644 @@
+--[[
+ Define SentenceNLLCriterion.
+ Implements Sentence-level log-likelihood as described in
+ Collobert et al., Natural Language Processing (almost) from Scratch, JMLR 12(2011).
+
+ This class tries to be both nn.Criterion and nn.Module at the same time.
+ (Criterion with learnable parameters that are required for run-time.)
+
+ This module requires double-precision calculations, so internally the input, model parameters, and output are cloned as double
+ and then converted back to the model's default types after the calculations.
+--]]
+local SentenceNLLCriterion, parent = torch.class('onmt.SentenceNLLCriterion', 'nn.Criterion')
+
+function SentenceNLLCriterion:__init(args, outputSize)
+ parent.__init(self)
+
+ if torch.type(outputSize) == 'table' then
+ outputSize = outputSize[1]
+ end
+
+ local N = outputSize
+ self.outputSize = N
+
+ self.A0 = torch.zeros(N) -- TagSize (N)
+ self.A = torch.zeros(N, N) -- TagSize (N) x TagSize (N)
+ self.dA0 = torch.zeros(N)
+ self.dA = torch.zeros(N,N)
+
+ if args.max_grad_norm then
+ self.max_grad_norm = args.max_grad_norm
+ else
+ self.max_grad_norm = 5
+ end
+end
+
+--[[ Return a new SentenceNLLCriterion using the serialized data `pretrained`. ]]
+function SentenceNLLCriterion.load(pretrained)
+ local self = torch.factory('onmt.SentenceNLLCriterion')()
+
+ parent.__init(self)
+ self.A0 = pretrained.A0
+ self.A = pretrained.A
+ self.outputSize = pretrained.outputSize
+ self.max_grad_norm = pretrained.max_grad_norm
+
+ return self
+end
+
+--[[ Return data to serialize. ]]
+function SentenceNLLCriterion:serialize()
+ return {
+ A0 = self.A0,
+ A = self.A,
+ outputSize = self.outputSize,
+ max_grad_norm = self.max_grad_norm,
+ float = self.float,
+ clearState = self.clearState,
+ apply = self.apply
+ }
+end
+
+function SentenceNLLCriterion:training()
+ self:_initTrainCache()
+end
+
+function SentenceNLLCriterion:evaluate()
+ self:renormalizeParams()
+end
+
+function SentenceNLLCriterion:release()
+ self:_freeTrainCache()
+ self:_freeViterbiCache()
+ self.dA0 = nil
+ self.dA = nil
+end
+
+function SentenceNLLCriterion:float()
+ if self.A0 then self.A0 = self.A0:float() end
+ if self.A then self.A = self.A:float() end
+ if self.dA0 then self.dA0 = self.dA0:float() end
+ if self.dA then self.dA = self.dA:float() end
+end
+
+function SentenceNLLCriterion:clearState()
+end
+
+function SentenceNLLCriterion:normalizeParams()
+ local N = self.outputSize
+
+ self.A0:add(-self.A0:min() + 0.000001)
+ self.A0:div(self.A0:sum())
+ self.A0:log()
+ self.A:add(-torch.min(self.A,2):expand(N, N) + 0.000001)
+ self.A:cdiv(self.A:sum(2):expand(N, N))
+ self.A:log()
+end
+
+function SentenceNLLCriterion:renormalizeParams()
+ self.A0:exp()
+ self.A:exp()
+ self:normalizeParams()
+end
+
+function SentenceNLLCriterion:postParametersInitialization()
+ self:normalizeParams()
+end
+
+function SentenceNLLCriterion:parameters()
+ return {self.A0, self.A}, {self.dA0, self.dA}
+end
+
+--[[
+ Viterbi search
+--]]
+function SentenceNLLCriterion:viterbiSearch(input, sourceSizes)
+ -- Input
+ -- input: BatchSize (B) x TagSize (N) x SeqLen (T): log-scale emission probabilities
+ -- sourceSizes: BatchSize (B) of data type Long
+ -- Output
+ -- preds: BatchSize (B) x SeqLen (T) of data type Long (index)
+
+ local F = input
+ local B = input:size(1)
+ local N = input:size(2) -- should equal self.outputSize
+ local T = input:size(3)
+ local _
+
+ if not self.cache_viterbi_preds then
+ self:_initViterbiCache()
+ end
+
+ local preds = self.cache_viterbi_preds:resize(B,T+1):zero() -- extra dimension in T for EOS handling
+
+ function SentenceNLLCriterion:_viterbiSearch_batch()
+ -- OpenNMT mini batches are padded to the left of source sequences
+ local isOnMask = self.cache_viterbi_isOnMask:resize(B,T+1):fill(1)
+ local isA0Mask = self.cache_viterbi_isA0Mask:resize(B,T+1):zero()
+ for b = 1, B do
+ for t = 1, (T - sourceSizes[b]) do
+ isOnMask[{b,t}] = 0
+ end
+ isA0Mask[{b, T+1-sourceSizes[b]}] = 1
+ end
+ local isAMask = self.cache_viterbi_isAMask:add(isOnMask, -isA0Mask)
+ local isMaxMask = isAMask
+ local isFMask = isOnMask[{{}, {1,T}}]
+
+ local maxScore = self.cache_viterbi_maxScore:resize(B, N, T+1)
+ local backPointer = self.cache_viterbi_backPointer:resize(B, N, T+1)
+
+ -- A0
+ local A0Score = nn.utils.addSingletonDimension(nn.utils.addSingletonDimension(self.A0, 1):expand(N, N), 1):expand(B, N, N)
+ -- A
+ local AScore = nn.utils.addSingletonDimension(self.A:t(), 1):expand(B, N, N)
+
+ for t = 1, T + 1 do
+ local scores = self.cache_viterbi_scores:resize(B,N,N):zero()
+
+ local A0ScoreMasked = self.cache_viterbi_XScoreMasked:resize(B,N,N)
+ A0ScoreMasked:cmul(A0Score, nn.utils.addSingletonDimension(nn.utils.addSingletonDimension(isA0Mask[{{},t}],2),3):expand(B, N, N))
+ scores:add(A0ScoreMasked)
+
+ if t > 1 then
+ local AScoreMasked = self.cache_viterbi_XScoreMasked:resize(B,N,N)
+ AScoreMasked:cmul(AScore, nn.utils.addSingletonDimension(nn.utils.addSingletonDimension(isAMask[{{},t}],2),3):expand(B, N, N))
+ scores:add(AScoreMasked)
+
+ -- maxScore
+ local MaxScore = nn.utils.addSingletonDimension(maxScore[{{},{},t-1}],2):expand(B, N, N)
+ local MaxScoreMasked = self.cache_viterbi_XScoreMasked:resize(B,N,N)
+ MaxScoreMasked:cmul(MaxScore, nn.utils.addSingletonDimension(nn.utils.addSingletonDimension(isMaxMask[{{},t}],2),3):expand(B, N, N))
+ scores:add(MaxScoreMasked)
+ end
+
+ if t < T + 1 then
+ -- F
+ local FScore = nn.utils.addSingletonDimension(F[{{},{},t}],3):expand(B, N, N)
+ local FScoreMasked = self.cache_viterbi_XScoreMasked:resize(B,N,N)
+ FScoreMasked:cmul(FScore, nn.utils.addSingletonDimension(nn.utils.addSingletonDimension(isFMask[{{},t}],2),3):expand(B, N, N))
+ scores:add(FScoreMasked)
+ end
+
+ maxScore[{{},{},t}], backPointer[{{},{},t}] = scores:max(3)
+ end
+
+ for b=1,B do
+ local pred = preds[b]
+ _, pred[T+1] = maxScore[{b,{},T+1}]:max(1)
+ for t=T+1,2+(T-sourceSizes[b]),-1 do
+ pred[t-1] = backPointer[{b,pred[t],t}]
+ end
+ end
+ end
+
+ function SentenceNLLCriterion:_viterbiSearch_loop()
+ local maxScore = onmt.utils.Cuda.convert(torch.Tensor(N, T+1))
+ local backPointer = onmt.utils.Cuda.convert(torch.LongTensor(N, T+1))
+
+ for b = 1, B do
+ maxScore:zero()
+ backPointer:zero()
+
+ -- OpenNMT mini batches are padded to the left of source sequences
+ local tOffset = T - sourceSizes[b]
+
+ maxScore[{{},1+tOffset}], backPointer[{{},1+tOffset}] = nn.utils.addSingletonDimension(torch.add(self.A0, F[{b,{},1+tOffset}]),2):max(2)
+
+ for t = 2+tOffset, T+1 do
+ local scores = torch.add(nn.utils.addSingletonDimension(maxScore[{{},t-1}],1):expand(N, N), self.A:t())
+ if t <= T then
+ scores:add(nn.utils.addSingletonDimension(F[{b,{},t}],2):expand(N,N))
+ end
+ maxScore[{{},t}], backPointer[{{},t}] = scores:max(2)
+ end
+
+ local pred = preds[b]
+ _, pred[T+1] = maxScore[{{},T+1}]:max(1)
+ for t=T+1,2+tOffset,-1 do
+ pred[t-1] = backPointer[{pred[t], t}]
+ end
+ end
+ end
+
+ self:_viterbiSearch_batch()
+-- self:_viterbiSearch_loop()
+
+ return preds[{{},{1,T}}]:clone()
+end
+
+function SentenceNLLCriterion:logsumexp(x)
+ -- Input
+ -- x: TagSize (N) or TagSize (N) x TagSize (N)
+
+ local N = x:size(1) -- should equal self.outputSize
+
+ local max, _ = x:max(1) -- 1 or 1 x N
+ local log_sum_exp
+ if x:nDimension() == 1 then
+ log_sum_exp = (x - max:expand(N)):exp():sum(1):log() -- 1
+ else
+ log_sum_exp = (x - max:expand(N,N)):exp():sum(1):log() -- 1 x N
+ end
+ -- find NaN values and assign a valid value
+ local NaN_mask = log_sum_exp:ne(log_sum_exp)
+ log_sum_exp[NaN_mask] = max:max()
+ return log_sum_exp:add(max):squeeze(1) -- 1 or N
+end
+
+function SentenceNLLCriterion:logsumexp_batch(x)
+ -- Input
+ -- x: B x N or B x N x N
+
+ local B = x:size(1)
+ local N = x:size(2) -- should equal self.outputSize
+
+ local max, _ = x:max(2) -- B x 1 or B x 1 x N
+ local log_sum_exp
+ if x:nDimension() == 2 then
+ log_sum_exp = (x - max:expand(B,N)):exp():sum(2):log() -- B x 1
+ else
+ log_sum_exp = (x - max:expand(B,N,N)):exp():sum(2):log() -- B x 1 x N
+ end
+ -- find NaN values and assign a valid value
+ local NaN_mask = log_sum_exp:ne(log_sum_exp)
+ log_sum_exp[NaN_mask] = max:max()
+ return log_sum_exp:add(max):squeeze(2) -- B x 1 or B x N
+end
+
+function SentenceNLLCriterion:updateOutput(input, target)
+
+ -- Input variables
+ -- input: BatchSize (B) x TagSize (N) x SeqLen (T)
+ -- target: BatchSize (B) x SeqLen (T)
+
+ -- Output variable
+ -- loss
+
+ local Y = target
+ local B = input:size(1)
+ local N = input:size(2) -- should equal self.outputSize
+ local T = input:size(3) + 1 -- extra T dimension for EOS
+
+ if not self.cache_F then -- Initialize cache when updateOutput() is called at inference time (e.g. to compute loss)
+ self:_initTrainCache()
+ end
+
+ local F = self.cache_F:resize(B, N, T)
+ F[{{},{},{1,input:size(3)}}]:copy(input)
+ F[{{}, {}, -1}] = 0.000001
+ F[{{}, onmt.Constants.EOS, -1}] = 1
+ F[{{}, {}, -1}]:log()
+
+ self.cache_delta:resize(F:size()):zero() -- B,N,T
+
+ self.cache_A0_dtype = self.A0:type(self.dtype)
+ self.cache_A_dtype = self.A:type(self.dtype)
+
+ local loss = 0.0
+
+ function SentenceNLLCriterion:_updateOutput_batch()
+ -- OpenNMT mini batches are padded to the left of source sequences
+ local isOnMask = self.cache_isOnMask:resize(B,T):fill(1)
+ local isA0Mask = self.cache_isA0Mask:resize(B,T):zero()
+ local yLookupTensor = self.cache_dF:resize(B,N,T):zero()
+ for b = 1, B do
+ for t = 1, T do
+ if Y[b][t] == onmt.Constants.PAD then
+ isOnMask[b][t] = 0
+ else
+ if t == 1 or (t > 1 and Y[b][t-1] == onmt.Constants.PAD) then
+ isA0Mask[b][t] = 1
+ end
+ yLookupTensor[b][ Y[b][t] ][t] = 1
+ end
+ end
+ end
+ local isAMask = self.cache_isAMask:add(isOnMask, -isA0Mask)
+ local isFMask = isOnMask
+
+ local delta = self.cache_delta
+ local delta_tmp = self.cache_delta_tmp:resize(B,N,N)
+
+ local refScores = torch.Tensor(B):type(self.dtype):zero()
+ local logLiks = torch.Tensor(B):type(self.dtype):zero()
+
+ -- A0
+ self.cache_A0_dtype_batch = nn.utils.addSingletonDimension(self.cache_A0_dtype, 1):expand(B,N)
+ local A0_dtype_batch = self.cache_A0_dtype_batch
+ -- A
+ self.cache_A_dtype_batch = nn.utils.addSingletonDimension(self.cache_A_dtype, 1):expand(B,N,N)
+ local A_dtype_batch = self.cache_A_dtype_batch
+
+ for t = 1, T do
+ -- refScore
+ refScores:add(torch.cmul(A0_dtype_batch,
+ yLookupTensor[{{},{},t}])
+ :cmul(nn.utils.addSingletonDimension(isA0Mask[{{},t}],2):expand(B,N))
+ :sum(2):squeeze(2) -- B
+ )
+ refScores:add(torch.cmul(F[{{},{},t}], yLookupTensor[{{},{},t}]):sum(2):squeeze(2)) -- B; yLookupTensor is already masked
+ if t > 1 then
+ refScores:add(torch.cmul(A_dtype_batch,
+ nn.utils.addSingletonDimension(yLookupTensor[{{},{},t-1}],3):expand(B,N,N)):sum(2) -- select t-1
+ :cmul(yLookupTensor[{{},{},t}]):sum(3):squeeze(3):squeeze(2) -- select t
+ ) -- B; yLookupTensor is already masked
+ end
+
+ --loglik
+ delta[{{},{},t}]:add(torch.cmul(A0_dtype_batch, nn.utils.addSingletonDimension(isA0Mask[{{},t}],2):expand(B,N))) -- B x N
+ delta[{{},{},t}]:add(torch.cmul(F[{{},{},t}], nn.utils.addSingletonDimension(isFMask[{{},t}],2):expand(B,N)))
+ if t > 1 then
+ delta_tmp:add( A_dtype_batch,
+ nn.utils.addSingletonDimension(delta[{{},{},t-1}],3):expand(B,N,N)
+ ):cmul(nn.utils.addSingletonDimension(nn.utils.addSingletonDimension(isAMask[{{},t}],2),3):expand(B,N,N)) -- B x N x N
+
+ delta[{{},{},t}]:add(self:logsumexp_batch(delta_tmp)) -- B x N
+ end
+ end
+
+ logLiks:add(refScores, -self:logsumexp_batch(delta[{{},{},T}]))
+ loss = -logLiks:sum()
+ end
+
+ function SentenceNLLCriterion:_updateOutput_loop()
+ for b = 1, B do
+
+ local delta = self.cache_delta[b]
+
+ -- OpenNMT mini batches are currently padded to the left of source sequences
+ local tOffset = 0
+ while Y[{b,1+tOffset}] == onmt.Constants.PAD do
+ tOffset = tOffset + 1
+ end
+
+ -- init state
+ local t_1 = 1 + tOffset
+ local referenceScore = self.A0[Y[b][t_1]] + F[b][Y[b][t_1]][t_1]
+ delta[{{},t_1}]:add(self.cache_A0_dtype, F[{b,{},t_1}])
+
+ -- fwd transition recursion
+ for t = 2 + tOffset, T do
+ local Y_t = Y[b][t]
+ local Y_t_1 = Y[b][t-1]
+
+ referenceScore = referenceScore + self.A[Y_t_1][Y_t] + F[b][Y_t][t]
+
+ self.cache_delta_tmp:add(self.cache_A_dtype, nn.utils.addSingletonDimension(delta[{{},t-1}],2):expand(N,N))
+ delta[{{},t}]:add(F[{b,{},t}], self:logsumexp(self.cache_delta_tmp))
+ end
+
+ local loglik = referenceScore - self:logsumexp(delta[{{},T}])
+ loss = loss - loglik
+ end
+ end
+
+ self:_updateOutput_batch()
+-- self:_updateOutput_loop()
+
+ return loss
+end
+
+function SentenceNLLCriterion:updateGradInput(input, target)
+ -- Input: F, A0, A
+ -- Output: dF, dA0, dA w.r.t Loss in target
+
+ local dF = self.cache_dF:resize(input:size()):zero()
+ local Y = target
+ local B = input:size(1)
+ local N = input:size(2) -- should equal self.outputSize
+ local T = input:size(3)+1
+
+ local dA_sum = self.cache_dA_sum:zero()
+ local dA0_sum = self.cache_dA0_sum:zero()
+
+ function SentenceNLLCriterion:_updateGradInput_batch()
+
+ local delta = self.cache_delta -- B x N x T; cached calculations from fwd path
+
+ -- Assume cached masks and batched A0, A are still valid
+ local isA0Mask = self.cache_isA0Mask
+ local isFMask = self.cache_isOnMask
+
+ local A_dtype_batch = self.cache_A_dtype_batch
+ local path_transition_probs = self.cache_path_transition_probs:resize(B,N,N)
+ local dA = self.cache_dA_tmp:resize(B,N,N):zero() -- B x N x N
+ local dA0 = self.cache_dA0_tmp:resize(B,N):zero() -- B x N
+
+ for b = 1, B do
+ for t = 1, T do
+ if Y[b][t] ~= onmt.Constants.PAD then
+ if t == 1 or (t > 1 and Y[b][t-1] == onmt.Constants.PAD) then
+ dA0[{b,Y[b][t]}] = dA0[{b,Y[b][t]}] - 1 -- A0
+ end
+ if t > 1 then
+ dA[{b,Y[b][t-1],Y[b][t]}] = dA[{b,Y[b][t-1],Y[b][t]}] - 1 -- A
+ end
+ if t < T then -- dF for the last EOS token does not exist
+ dF[{b,Y[b][t],t}] = dF[{b,Y[b][t],t}] - 1 -- F
+ end
+ end
+ end
+ end
+
+ local deriv_Clogadd = delta[{{},{},T}]:exp()
+ deriv_Clogadd:cdiv(deriv_Clogadd:sum(2):expand(B,N))
+ deriv_Clogadd[deriv_Clogadd:ne(deriv_Clogadd)] = 0
+
+ for t = T,1,-1 do
+ if t < T then
+ -- F
+ dF[{{},{},t}]:add(torch.cmul(deriv_Clogadd, nn.utils.addSingletonDimension(isFMask[{{},t}],2):expand(B,N)))
+ -- A0
+ dA0:add(torch.cmul(deriv_Clogadd, nn.utils.addSingletonDimension(isA0Mask[{{},t}],2):expand(B,N)))
+ end
+ -- A
+ if t > 1 then
+ path_transition_probs:add(A_dtype_batch,
+ nn.utils.addSingletonDimension(
+ delta[{{},{},t-1}] -- delta is calculated from masked
+ ,3):expand(B,N,N))
+ path_transition_probs:exp()
+ path_transition_probs:cdiv(path_transition_probs:sum(2):expand(B,N,N))
+ path_transition_probs[path_transition_probs:ne(path_transition_probs)] = 0
+ path_transition_probs:cmul(nn.utils.addSingletonDimension(deriv_Clogadd, 2):expand(B,N,N))
+
+ local dAt = path_transition_probs
+ dA:add(dAt)
+
+ deriv_Clogadd = dAt:sum(3):squeeze(3)
+ for b = 1, B do
+ onmt.train.Optim.clipGradByNorm({deriv_Clogadd[b]}, self.max_grad_norm)
+ end
+ end
+ end
+
+ dA_sum:add(dA:sum(1):squeeze(1))
+ dA0_sum:add(dA0:sum(1):squeeze(1))
+ end
+
+ function SentenceNLLCriterion:_updateGradInput_loop()
+ local A_dtype = self.cache_A_dtype
+
+ for b = 1, B do
+ -- OpenNMT mini batches are padded to the left of source sequences
+ local tOffset = 0
+ while Y[{b,1+tOffset}] == onmt.Constants.PAD do
+ tOffset = tOffset + 1
+ end
+
+ local delta = self.cache_delta[b] -- N x T
+
+ local dA = self.cache_dA_tmp:zero()
+ local dA0 = self.cache_dA0_tmp:zero()
+
+ local deriv_Clogadd = delta[{{},T}]:exp()
+ deriv_Clogadd:div(deriv_Clogadd:sum())
+ deriv_Clogadd[deriv_Clogadd:ne(deriv_Clogadd)] = 0
+
+ for t= T, (2+tOffset), -1 do
+ local Y_t = Y[b][t]
+ local Y_t_1 = Y[b][t-1]
+
+ if t < T then -- dF for the last EOS token does not exist
+ dF[{b,Y_t,t}] = dF[{b,Y_t,t}] - 1
+ end
+ dA[{Y_t_1,Y_t}] = dA[{Y_t_1,Y_t}] - 1
+
+ -- compute and add partial derivatives w.r.t transition scores
+ local path_transition_probs = self.cache_path_transition_probs
+ path_transition_probs:add(A_dtype, nn.utils.addSingletonDimension(delta[{{},t-1}],2):expand(N,N))
+ path_transition_probs:exp()
+ path_transition_probs:cdiv(path_transition_probs:sum(1):expand(N,N))
+ path_transition_probs[path_transition_probs:ne(path_transition_probs)] = 0
+
+ if t < T then
+ dF[{b,{},t}]:add(deriv_Clogadd)
+ end
+ path_transition_probs:cmul(nn.utils.addSingletonDimension(deriv_Clogadd, 1):expand(N,N))
+ local dAt = path_transition_probs
+ dA:add(dAt)
+ deriv_Clogadd = dAt:sum(2):squeeze(2)
+
+ onmt.train.Optim.clipGradByNorm({deriv_Clogadd}, self.max_grad_norm)
+ end
+
+ local t = 1 + tOffset
+ local Y_t = Y[b][t]
+ dF[{b,Y_t,t}] = dF[{b,Y_t,t}] - 1
+ dA0[Y_t] = dA0[Y_t] - 1
+
+ dF[{b,{},t}]:add(deriv_Clogadd)
+ dA0:add(deriv_Clogadd)
+
+ dA_sum:add(dA)
+ dA0_sum:add(dA0)
+ end
+ end
+
+-- self:_updateGradInput_loop()
+ self:_updateGradInput_batch()
+
+ self.dA:add(onmt.utils.Cuda.convert(dA_sum/B))
+ self.dA0:add(onmt.utils.Cuda.convert(dA0_sum/B))
+
+ if not self.gradInput then
+ self.gradInput = onmt.utils.Cuda.convert(torch.Tensor())
+ end
+ self.gradInput:resize(input:size())
+ self.gradInput:copy(dF)
+
+ return self.gradInput
+end
+
+function SentenceNLLCriterion:_initViterbiCache()
+ local N = self.outputSize
+ self.cache_viterbi_preds = onmt.utils.Cuda.convert(torch.LongTensor(1,1))
+ self.cache_viterbi_maxScore = onmt.utils.Cuda.convert(torch.Tensor(1, N, 1))
+ self.cache_viterbi_backPointer = onmt.utils.Cuda.convert(torch.LongTensor(1, N, 1))
+ self.cache_viterbi_isOnMask = onmt.utils.Cuda.convert(torch.Tensor(1, 1))
+ self.cache_viterbi_isA0Mask = onmt.utils.Cuda.convert(torch.Tensor(1, 1))
+ self.cache_viterbi_isAMask = onmt.utils.Cuda.convert(torch.Tensor(1, 1))
+ self.cache_viterbi_scores = onmt.utils.Cuda.convert(torch.Tensor(1, N, N))
+ self.cache_viterbi_XScoreMasked = onmt.utils.Cuda.convert(torch.Tensor(1, N, N))
+end
+
+function SentenceNLLCriterion:_freeViterbiCache()
+ self.cache_viterbi_preds = nil
+ self.cache_viterbi_maxScore = nil
+ self.cache_viterbi_backPointer = nil
+ self.cache_viterbi_isOnMask = nil
+ self.cache_viterbi_isA0Mask = nil
+ self.cache_viterbi_isAMask = nil
+ self.cache_viterbi_scores = nil
+ self.cache_viterbi_XScoreMasked = nil
+end
+
+function SentenceNLLCriterion:_initTrainCache()
+ local N = self.outputSize
+
+ self.dtype = onmt.utils.Cuda.activated and 'torch.CudaDoubleTensor' or 'torch.DoubleTensor'
+
+ self.cache_dA_sum = torch.Tensor():type(self.dtype):resize(N,N)
+ self.cache_dA0_sum = torch.Tensor():type(self.dtype):resize(N)
+ self.cache_dA_tmp = torch.Tensor():type(self.dtype):resize(N,N)
+ self.cache_dA0_tmp = torch.Tensor():type(self.dtype):resize(N)
+ self.cache_path_transition_probs = torch.Tensor():type(self.dtype):resize(N,N)
+
+ self.cache_F = torch.Tensor():type(self.dtype):resize(1,N,1)
+ self.cache_dF = torch.Tensor():type(self.dtype):resize(1,N,1)
+ self.cache_delta = torch.Tensor():type(self.dtype):resize(1,N,1)
+ self.cache_delta_tmp = torch.Tensor():type(self.dtype):resize(N,N)
+
+ self.cache_isOnMask = torch.Tensor(1, 1):type(self.dtype)
+ self.cache_isA0Mask = torch.Tensor(1, 1):type(self.dtype)
+ self.cache_isAMask = torch.Tensor(1, 1):type(self.dtype)
+end
+
+function SentenceNLLCriterion:_freeTrainCache()
+ self.dtype = nil
+ self.cache_dA_sum = nil
+ self.cache_dA0_sum = nil
+ self.cache_dA_tmp = nil
+ self.cache_dA0_tmp = nil
+ self.cache_path_transition_probs = nil
+
+ self.cache_F = nil
+ self.cache_dF = nil
+ self.cache_delta = nil
+ self.cache_delta_tmp = nil
+
+ self.cache_isOnMask = nil
+ self.cache_isA0Mask = nil
+ self.cache_isAMask = nil
+end
+
+--[[
+Copied from nn.Module
+--]]
+function SentenceNLLCriterion:getParameters()
+ -- get parameters
+ local parameters,gradParameters = self:parameters()
+ local p, g = nn.Module.flatten(parameters), nn.Module.flatten(gradParameters)
+ assert(p:nElement() == g:nElement(),
+ 'check that you are sharing parameters and gradParameters')
+ if parameters then
+ for i=1,#parameters do
+ assert(parameters[i]:storageOffset() == gradParameters[i]:storageOffset(),
+ 'misaligned parameter at ' .. tostring(i))
+ end
+ end
+ return p, g
+end
+
+--[[
+Copied from nn.Module
+--]]
+function SentenceNLLCriterion:apply(callback)
+ callback(self)
+ if self.modules then
+ for _, module in ipairs(self.modules) do
+ module:apply(callback)
+ end
+ end
+end
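
For reference, updateOutput() above computes the sentence-level objective of Collobert et al. (2011) in the following form (a restatement of the code, with F the emission scores, A0 the initial-tag scores, A the tag-transition scores, and y the gold tag sequence including the trailing EOS):

  s(x, y)    = A0[y_1] + sum_{t=2..T} A[y_{t-1}, y_t] + sum_{t=1..T} F[y_t, t]
  loss(x, y) = -( s(x, y) - logadd over all tag sequences y' of s(x, y') )

In the code, refScores accumulates s(x, y) for the gold path, the recursion over delta computes the logadd term with log-sum-exp stabilisation, and updateGradInput() backpropagates through both terms to produce dF, dA0, and dA.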
diff --git a/onmt/modules/init.lua b/onmt/modules/init.lua
index 4eeb42c3..c88e3d77 100644
--- a/onmt/modules/init.lua
+++ b/onmt/modules/init.lua
@@ -31,4 +31,6 @@ require('onmt.modules.JoinReplicateTable')
require('onmt.modules.ParallelClassNLLCriterion')
require('onmt.modules.RIndexLinear')
+require('onmt.modules.SentenceNLLCriterion')
+
return onmt
diff --git a/onmt/tagger/Tagger.lua b/onmt/tagger/Tagger.lua
index 64d152dd..ea4fa62e 100644
--- a/onmt/tagger/Tagger.lua
+++ b/onmt/tagger/Tagger.lua
@@ -61,6 +61,20 @@ function Tagger:buildInput(tokens)
return data
end
+function Tagger:buildInputGold(tokens)
+ local data = {}
+
+ local words, features = onmt.utils.Features.extract(tokens)
+
+ data.words = words
+
+ if #features > 0 then
+ data.features = features
+ end
+
+ return data
+end
+
function Tagger:buildOutput(data)
return table.concat(onmt.utils.Features.annotate(data.words, data.features), ' ')
end
@@ -97,6 +111,57 @@ function Tagger:buildData(src)
return onmt.data.Dataset.new(srcData), ignored, indexMap
end
+function Tagger:buildGoldData(src, tgt)
+ local srcData = {}
+ srcData.words = {}
+ srcData.features = {}
+
+ local tgtData = {}
+ tgtData.words = {}
+ tgtData.features = {}
+
+ local ignored = {}
+ local indexMap = {}
+ local index = 1
+
+ for b = 1, #src do
+ if (src[b].words and #src[b].words == 0) or (tgt[b].words and #tgt[b].words == 0) then
+ table.insert(ignored, b)
+ else
+ indexMap[index] = b
+ index = index + 1
+
+ if self.dicts.src then
+ table.insert(srcData.words,
+ self.dicts.src.words:convertToIdx(src[b].words, onmt.Constants.UNK_WORD))
+ if #self.dicts.src.features > 0 then
+ table.insert(srcData.features,
+ onmt.utils.Features.generateSource(self.dicts.src.features, src[b].features))
+ end
+ else
+ table.insert(srcData.words,onmt.utils.Cuda.convert(src[b].vectors))
+ end
+
+ if self.dicts.tgt then
+ table.insert(tgtData.words,
+ self.dicts.tgt.words:convertToIdx(tgt[b].words,
+ onmt.Constants.UNK_WORD,
+ onmt.Constants.BOS_WORD,
+ onmt.Constants.EOS_WORD))
+
+ if #self.dicts.tgt.features > 0 then
+ table.insert(tgtData.features,
+ onmt.utils.Features.generateTarget(self.dicts.tgt.features, tgt[b].features))
+ end
+ else
+ table.insert(tgtData.words,onmt.utils.Cuda.convert(tgt[b].vectors))
+ end
+ end
+ end
+
+ return onmt.data.Dataset.new(srcData, tgtData), ignored, indexMap
+end
+
function Tagger:buildTargetWords(pred)
local tokens = self.dicts.tgt.words:convertToLabels(pred, onmt.Constants.EOS)
@@ -124,40 +189,6 @@ function Tagger:buildTargetFeatures(predFeats)
return feats
end
-function Tagger:tagBatch(batch)
- local pred = {}
- local feats = {}
- for _ = 1, batch.size do
- table.insert(pred, {})
- table.insert(feats, {})
- end
- local _, context = self.model.models.encoder:forward(batch)
-
- for t = 1, batch.sourceLength do
- local out = self.model.models.generator:forward(context:select(2, t))
- if type(out[1]) == 'table' then
- out = out[1]
- end
- local _, best = out[1]:max(2)
- for b = 1, batch.size do
- if t > batch.sourceLength - batch.sourceSize[b] then
- pred[b][t - batch.sourceLength + batch.sourceSize[b]] = best[b][1]
- feats[b][t - batch.sourceLength + batch.sourceSize[b]] = {}
- end
- end
- for j = 2, #out do
- _, best = out[j]:max(2)
- for b = 1, batch.size do
- if t > batch.sourceLength - batch.sourceSize[b] then
- feats[b][t - batch.sourceLength + batch.sourceSize[b]][j - 1] = best[b][1]
- end
- end
- end
- end
-
- return pred, feats
-end
-
--[[ Tag a batch of source sequences.
Parameters:
@@ -180,7 +211,7 @@ function Tagger:tag(src)
if data:batchCount() > 0 then
local batch = onmt.utils.Cuda.convert(data:getBatch())
- local pred, predFeats = self:tagBatch(batch)
+ local pred, predFeats = self.model:tagBatch(batch)
for b = 1, batch.size do
results[b] = {}
@@ -196,4 +227,15 @@ function Tagger:tag(src)
return results
end
+function Tagger:computeLosses(src, tgt)
+ local losses = {}
+ for b=1,#src do
+ local data, _ = self:buildGoldData({src[b]}, {tgt[b]})
+ local batch = onmt.utils.Cuda.convert(data:getBatch())
+ local loss = self.model:forwardComputeLoss(batch)
+ table.insert(losses, loss)
+ end
+ return losses
+end
+
return Tagger
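
Finally, a minimal sketch of driving the new gold-scoring path end to end (the Tagger construction and dictionaries are assumed; only the call shapes come from the functions added above):

  -- `tagger` is an already-constructed onmt.tagger.Tagger instance (construction omitted)
  local src, tgt = {}, {}
  table.insert(src, tagger:buildInput({ 'She', 'sells', 'shells' }))
  table.insert(tgt, tagger:buildInputGold({ 'PRP', 'VBZ', 'NNS' }))
  local losses = tagger:computeLosses(src, tgt)   -- one loss per sentence
  print(losses[1])                                -- negative log-likelihood of the gold tag sequence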