Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-11-10 20:09:49 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-11-10 20:09:49 +0400
commit2ee96f5089bb160882c836ed533028f75e3d0c54 (patch)
tree3c21609cedb135a5738ea387f43b6d07ee896424
parent0a91f301dceec8c28a7ba51d06bcf8f616ca5354 (diff)
Added optional distance mode for NLL criterion.
-rw-r--r--DistNLLCriterion.lua18
1 files changed, 16 insertions, 2 deletions
diff --git a/DistNLLCriterion.lua b/DistNLLCriterion.lua
index fedda1b..8f2528a 100644
--- a/DistNLLCriterion.lua
+++ b/DistNLLCriterion.lua
@@ -3,6 +3,7 @@ local DistNLLCriterion, parent = torch.class('nn.DistNLLCriterion', 'nn.Criterio
function DistNLLCriterion:__init()
parent.__init(self)
-- user options
+ self.inputIsADistance = false
self.inputIsProbability = false
self.inputIsLogProbability = false
self.targetIsProbability = false
@@ -10,6 +11,7 @@ function DistNLLCriterion:__init()
self.targetSoftMax = nn.SoftMax()
self.inputLogSoftMax = nn.LogSoftMax()
self.gradLogInput = torch.Tensor()
+ self.input = torch.Tensor()
end
function DistNLLCriterion:normalize(input, target)
@@ -20,13 +22,20 @@ function DistNLLCriterion:normalize(input, target)
self.probTarget = target
end
+ -- flip input if a distance
+ if self.inputIsADistance then
+ self.input:resizeAs(input):copy(input):mul(-1)
+ else
+ self.input = input
+ end
+
-- normalize input
if not self.inputIsLogProbability and not self.inputIsProbability then
- self.logProbInput = self.inputLogSoftMax:forward(input)
+ self.logProbInput = self.inputLogSoftMax:forward(self.input)
elseif not self.inputIsLogProbability then
print('TODO: implement nn.Log()')
else
- self.logProbInput = input
+ self.logProbInput = self.input
end
end
@@ -39,6 +48,11 @@ function DistNLLCriterion:denormalize(input)
else
self.gradInput = self.gradLogInput
end
+
+ -- if input is a distance, then flip gradients back
+ if self.inputIsADistance then
+ self.gradInput:mul(-1)
+ end
end
function DistNLLCriterion:forward(input, target)