diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-04-02 23:24:29 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-04-02 23:24:29 +0400 |
commit | bf7a27b2a6f1de76b72fe5c3b3b75bd92bf856ab (patch) | |
tree | 95a9151b57caf8ebed7c8492696d94c4d36c1ba6 | |
parent | de5dc53c9339ad77aee8581904e69ee879a8fe3e (diff) |
KL Divergence criterion, for continuous distributions.
This is the analogous of ClassNLL, but for non 1-of-N
distributions.
-rw-r--r-- | DistKLDivCriterion.lua | 64 | ||||
-rw-r--r-- | init.lua | 1 |
2 files changed, 65 insertions, 0 deletions
diff --git a/DistKLDivCriterion.lua b/DistKLDivCriterion.lua new file mode 100644 index 0000000..b6b6216 --- /dev/null +++ b/DistKLDivCriterion.lua @@ -0,0 +1,64 @@ +local DistKLDivCriterion, parent = torch.class('nn.DistKLDivCriterion', 'nn.Criterion') + +local epsilon = 1e-100 + +function DistKLDivCriterion:__init() + parent.__init(self) + self.sizeAverage = true +end + +function DistKLDivCriterion:updateOutput(input, target) + local log = math.log + if input:dim() == 1 then + self.output = 0 + for i = 1,input:size(1) do + local acc = 0 + if target[i] > 0 then + acc = target[i] * (log(target[i]) - input[i]) + end + self.output = self.output + acc + end + elseif input:dim() == 2 then + self.output = 0 + for i=1,target:size(1) do + local tar = target[i] + local inp = input[i] + for i = 1,inp:size(1) do + local acc = 0 + if tar[i] > epsilon then + acc = tar[i] * (log(tar[i]) - inp[i]) + end + self.output = self.output + acc + end + end + if self.sizeAverage then + self.output = self.output / target:size(1) + end + else + error('matrix or vector expected') + end + return self.output +end + +function DistKLDivCriterion:updateGradInput(input, target) + local gradInput = self.gradInput + gradInput:resizeAs(input) + + if input:dim() == 1 then + for i = 1,input:size(1) do + gradInput[i] = -target[i] + end + else + for i=1,target:size(1) do + local tar = target[i] + for i = 1,tar:size(1) do + gradInput[i] = -tar[i] + end + end + if self.sizeAverage then + gradInput:div(target:size(1)) + end + end + + return self.gradInput +end @@ -79,6 +79,7 @@ torch.include('nn', 'MSECriterion.lua') torch.include('nn', 'MarginCriterion.lua') torch.include('nn', 'AbsCriterion.lua') torch.include('nn', 'ClassNLLCriterion.lua') +torch.include('nn', 'DistKLDivCriterion.lua') torch.include('nn', 'MultiCriterion.lua') torch.include('nn', 'L1HingeEmbeddingCriterion.lua') torch.include('nn', 'HingeEmbeddingCriterion.lua') |