Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-08-30 00:55:12 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-08-30 00:55:12 +0400
commit398e21cfaf6a8e4fa47edd9e6de081497e75dda2 (patch)
tree6d9779b6d7fa47cf518993c2d788429369a58264 /DistNLLCriterion.lua
parent0b7f3b9cd4e578a96cfc6b50b18bca2f27cc4682 (diff)
Added DistNLLCriterion, to support Neg Likelihood for distributions.
ClassNLLCriterion only supports simple distributions, e.g. one-of-N. DistNLLCriterion supports arbitrary distributions.
Diffstat (limited to 'DistNLLCriterion.lua')
-rw-r--r--DistNLLCriterion.lua61
1 files changed, 61 insertions, 0 deletions
diff --git a/DistNLLCriterion.lua b/DistNLLCriterion.lua
new file mode 100644
index 0000000..c0b69e3
--- /dev/null
+++ b/DistNLLCriterion.lua
@@ -0,0 +1,61 @@
+local DistNLLCriterion, parent = torch.class('nn.DistNLLCriterion', 'nn.Criterion')
+
+function DistNLLCriterion:__init()
+ parent.__init(self)
+ -- user options
+ self.inputIsProbability = false
+ self.inputIsLogProbability = false
+ self.targetIsProbability = false
+ -- internal
+ self.targetSoftMax = nn.SoftMax()
+ self.inputLogSoftMax = nn.LogSoftMax()
+ self.gradLogInput = torch.Tensor()
+end
+
+function DistNLLCriterion:normalize(input, target)
+ -- normalize target
+ if not self.targetIsProbability then
+ self.probTarget = self.targetSoftMax:forward(target)
+ else
+ self.probTarget = target
+ end
+
+ -- normalize input
+ if not self.inputIsLogProbability and not self.inputIsProbability then
+ self.logProbInput = self.inputLogSoftMax:forward(input)
+ elseif not self.inputIsLogProbability then
+ print('TODO: implement nn.Log()')
+ else
+ self.logProbInput = input
+ end
+end
+
+function DistNLLCriterion:denormalize(input)
+ -- denormalize gradients
+ if not self.inputIsLogProbability and not self.inputIsProbability then
+ self.gradInput = self.inputLogSoftMax:backward(input, self.gradLogInput)
+ elseif not self.inputIsLogProbability then
+ print('TODO: implement nn.Log()')
+ else
+ self.gradInput = self.gradLogInput
+ end
+end
+
+function DistNLLCriterion:forward(input, target)
+ self:normalize(input, target)
+ self.output = 0
+ for i = 1,input:size(1) do
+ self.output = self.output - self.logProbInput[i] * self.probTarget[i]
+ end
+ return self.output
+end
+
+function DistNLLCriterion:backward(input, target)
+ self:normalize(input, target)
+ self.gradLogInput:resizeAs(input)
+ for i = 1,input:size(1) do
+ self.gradLogInput[i] = -self.probTarget[i]
+ end
+ self:denormalize(input)
+ return self.gradInput
+end