author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-30 00:55:12 +0400
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-30 00:55:12 +0400
commit | 398e21cfaf6a8e4fa47edd9e6de081497e75dda2 (patch)
tree | 6d9779b6d7fa47cf518993c2d788429369a58264 /DistNLLCriterion.lua
parent | 0b7f3b9cd4e578a96cfc6b50b18bca2f27cc4682 (diff)
Added DistNLLCriterion, to support negative log-likelihood for distributions.
ClassNLLCriterion only supports simple targets, e.g. one-of-N class labels.
DistNLLCriterion supports arbitrary target distributions.
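
For illustration, a minimal usage sketch (hypothetical values; assumes a Torch7 environment with this patch applied). Unlike ClassNLLCriterion, which takes a single class index as target, DistNLLCriterion takes a whole target distribution; setting targetIsProbability tells it the target is already normalized:

require 'nn'

-- hypothetical 3-class example
local criterion = nn.DistNLLCriterion()
criterion.targetIsProbability = true             -- target below is already a distribution

local input  = torch.Tensor({1.0, 2.0, 0.5})     -- raw scores; LogSoftMax is applied internally
local target = torch.Tensor({0.1, 0.7, 0.2})     -- arbitrary distribution, not one-of-N

local loss = criterion:forward(input, target)
local gradInput = criterion:backward(input, target)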
Diffstat (limited to 'DistNLLCriterion.lua')
-rw-r--r-- | DistNLLCriterion.lua | 61
1 file changed, 61 insertions, 0 deletions
diff --git a/DistNLLCriterion.lua b/DistNLLCriterion.lua
new file mode 100644
index 0000000..c0b69e3
--- /dev/null
+++ b/DistNLLCriterion.lua
@@ -0,0 +1,61 @@
+local DistNLLCriterion, parent = torch.class('nn.DistNLLCriterion', 'nn.Criterion')
+
+function DistNLLCriterion:__init()
+   parent.__init(self)
+   -- user options
+   self.inputIsProbability = false
+   self.inputIsLogProbability = false
+   self.targetIsProbability = false
+   -- internal
+   self.targetSoftMax = nn.SoftMax()
+   self.inputLogSoftMax = nn.LogSoftMax()
+   self.gradLogInput = torch.Tensor()
+end
+
+function DistNLLCriterion:normalize(input, target)
+   -- normalize target
+   if not self.targetIsProbability then
+      self.probTarget = self.targetSoftMax:forward(target)
+   else
+      self.probTarget = target
+   end
+
+   -- normalize input
+   if not self.inputIsLogProbability and not self.inputIsProbability then
+      self.logProbInput = self.inputLogSoftMax:forward(input)
+   elseif not self.inputIsLogProbability then
+      print('TODO: implement nn.Log()')
+   else
+      self.logProbInput = input
+   end
+end
+
+function DistNLLCriterion:denormalize(input)
+   -- denormalize gradients
+   if not self.inputIsLogProbability and not self.inputIsProbability then
+      self.gradInput = self.inputLogSoftMax:backward(input, self.gradLogInput)
+   elseif not self.inputIsLogProbability then
+      print('TODO: implement nn.Log()')
+   else
+      self.gradInput = self.gradLogInput
+   end
+end
+
+function DistNLLCriterion:forward(input, target)
+   self:normalize(input, target)
+   self.output = 0
+   for i = 1,input:size(1) do
+      self.output = self.output - self.logProbInput[i] * self.probTarget[i]
+   end
+   return self.output
+end
+
+function DistNLLCriterion:backward(input, target)
+   self:normalize(input, target)
+   self.gradLogInput:resizeAs(input)
+   for i = 1,input:size(1) do
+      self.gradLogInput[i] = -self.probTarget[i]
+   end
+   self:denormalize(input)
+   return self.gradInput
+end
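
The forward pass computes the cross-entropy -sum_i target[i] * log p(input)[i]; the gradient with respect to the log-probabilities is simply -target, which denormalize then maps back through LogSoftMax:backward. A quick finite-difference check of that gradient (a sketch with made-up values; not part of the commit):

require 'nn'

local crit = nn.DistNLLCriterion()
crit.targetIsProbability = true

local input  = torch.Tensor({0.2, -1.0, 0.7})
local target = torch.Tensor({0.3, 0.3, 0.4})
local eps = 1e-5

-- analytic gradient from the criterion
local analytic = crit:backward(input, target):clone()

-- numeric gradient by central differences; each dimension should match
for i = 1, input:size(1) do
   local orig = input[i]
   input[i] = orig + eps
   local fplus = crit:forward(input, target)
   input[i] = orig - eps
   local fminus = crit:forward(input, target)
   input[i] = orig
   print(string.format('dim %d: analytic %.6f, numeric %.6f',
                       i, analytic[i], (fplus - fminus) / (2 * eps)))
end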