
github.com/torch/nn.git
author     Clement Farabet <clement.farabet@gmail.com>  2012-04-02 23:24:29 +0400
committer  Clement Farabet <clement.farabet@gmail.com>  2012-04-02 23:24:29 +0400
commit     bf7a27b2a6f1de76b72fe5c3b3b75bd92bf856ab
tree       95a9151b57caf8ebed7c8492696d94c4d36c1ba6
parent     de5dc53c9339ad77aee8581904e69ee879a8fe3e
KL Divergence criterion, for continuous distributions.
This is the analogue of ClassNLL, but for targets that are full distributions rather than 1-of-N.
-rw-r--r--  DistKLDivCriterion.lua  64
-rw-r--r--  init.lua                 1
2 files changed, 65 insertions, 0 deletions
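Not part of the commit, but for context: a minimal usage sketch of the new criterion once this commit is applied. By analogy with ClassNLL and from the formula in updateOutput, it assumes the input holds log-probabilities (e.g. the output of nn.LogSoftMax); the tensor values are made up for illustration.

require 'nn'

-- minimal sketch (illustrative values, not part of the commit)
local crit = nn.DistKLDivCriterion()

-- input: log-probabilities (e.g. from nn.LogSoftMax); target: a probability distribution
local input  = nn.LogSoftMax():forward(torch.randn(5))
local target = torch.rand(5)
target:div(target:sum())                    -- normalize the target so it sums to 1

local loss = crit:forward(input, target)    -- sum_i target[i] * (log(target[i]) - input[i])
local grad = crit:backward(input, target)   -- elementwise -target for a 1D input
print(loss, grad)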
diff --git a/DistKLDivCriterion.lua b/DistKLDivCriterion.lua
new file mode 100644
index 0000000..b6b6216
--- /dev/null
+++ b/DistKLDivCriterion.lua
@@ -0,0 +1,64 @@
+local DistKLDivCriterion, parent = torch.class('nn.DistKLDivCriterion', 'nn.Criterion')
+
+local epsilon = 1e-100
+
+function DistKLDivCriterion:__init()
+   parent.__init(self)
+   self.sizeAverage = true
+end
+
+function DistKLDivCriterion:updateOutput(input, target)
+   local log = math.log
+   if input:dim() == 1 then
+      self.output = 0
+      for i = 1,input:size(1) do
+         local acc = 0
+         if target[i] > 0 then
+            acc = target[i] * (log(target[i]) - input[i])
+         end
+         self.output = self.output + acc
+      end
+   elseif input:dim() == 2 then
+      self.output = 0
+      for i = 1,target:size(1) do
+         local tar = target[i]
+         local inp = input[i]
+         for j = 1,inp:size(1) do
+            local acc = 0
+            if tar[j] > epsilon then
+               acc = tar[j] * (log(tar[j]) - inp[j])
+            end
+            self.output = self.output + acc
+         end
+      end
+      if self.sizeAverage then
+         self.output = self.output / target:size(1)
+      end
+   else
+      error('matrix or vector expected')
+   end
+   return self.output
+end
+
+function DistKLDivCriterion:updateGradInput(input, target)
+   local gradInput = self.gradInput
+   gradInput:resizeAs(input)
+
+   if input:dim() == 1 then
+      for i = 1,input:size(1) do
+         gradInput[i] = -target[i]
+      end
+   else
+      for i = 1,target:size(1) do
+         local tar = target[i]
+         for j = 1,tar:size(1) do
+            gradInput[i][j] = -tar[j]
+         end
+      end
+      if self.sizeAverage then
+         gradInput:div(target:size(1))
+      end
+   end
+
+   return self.gradInput
+end
diff --git a/init.lua b/init.lua
index 4fe448e..d589eec 100644
--- a/init.lua
+++ b/init.lua
@@ -79,6 +79,7 @@ torch.include('nn', 'MSECriterion.lua')
torch.include('nn', 'MarginCriterion.lua')
torch.include('nn', 'AbsCriterion.lua')
torch.include('nn', 'ClassNLLCriterion.lua')
+torch.include('nn', 'DistKLDivCriterion.lua')
torch.include('nn', 'MultiCriterion.lua')
torch.include('nn', 'L1HingeEmbeddingCriterion.lua')
torch.include('nn', 'HingeEmbeddingCriterion.lua')
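For the batched (2D) path, a similar sketch (again illustrative, not part of the commit): each row of input is a log-probability vector and each row of target a probability distribution; with sizeAverage left at its default of true, the summed KL terms and the gradient are divided by the number of rows.

require 'nn'

-- batched sketch (illustrative shapes and values, not part of the commit)
local crit = nn.DistKLDivCriterion()
local batch, dim = 4, 10

local probs = torch.rand(batch, dim)
for i = 1, batch do probs[i]:div(probs[i]:sum()) end    -- each row sums to 1
local input = torch.log(probs)                          -- rows of log-probabilities

local target = torch.rand(batch, dim)
for i = 1, batch do target[i]:div(target[i]:sum()) end  -- rows of target distributions

local loss = crit:forward(input, target)    -- per-sample KL terms summed, divided by batch
local grad = crit:backward(input, target)   -- elementwise -target, divided by batch

-- set crit.sizeAverage = false to keep the plain sum over all samples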