diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-11-10 20:09:49 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-11-10 20:09:49 +0400 |
commit | 2ee96f5089bb160882c836ed533028f75e3d0c54 (patch) | |
tree | 3c21609cedb135a5738ea387f43b6d07ee896424 | |
parent | 0a91f301dceec8c28a7ba51d06bcf8f616ca5354 (diff) |
Added optional distance mode for NLL criterion.
-rw-r--r-- | DistNLLCriterion.lua | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/DistNLLCriterion.lua b/DistNLLCriterion.lua index fedda1b..8f2528a 100644 --- a/DistNLLCriterion.lua +++ b/DistNLLCriterion.lua @@ -3,6 +3,7 @@ local DistNLLCriterion, parent = torch.class('nn.DistNLLCriterion', 'nn.Criterio function DistNLLCriterion:__init() parent.__init(self) -- user options + self.inputIsADistance = false self.inputIsProbability = false self.inputIsLogProbability = false self.targetIsProbability = false @@ -10,6 +11,7 @@ function DistNLLCriterion:__init() self.targetSoftMax = nn.SoftMax() self.inputLogSoftMax = nn.LogSoftMax() self.gradLogInput = torch.Tensor() + self.input = torch.Tensor() end function DistNLLCriterion:normalize(input, target) @@ -20,13 +22,20 @@ function DistNLLCriterion:normalize(input, target) self.probTarget = target end + -- flip input if a distance + if self.inputIsADistance then + self.input:resizeAs(input):copy(input):mul(-1) + else + self.input = input + end + -- normalize input if not self.inputIsLogProbability and not self.inputIsProbability then - self.logProbInput = self.inputLogSoftMax:forward(input) + self.logProbInput = self.inputLogSoftMax:forward(self.input) elseif not self.inputIsLogProbability then print('TODO: implement nn.Log()') else - self.logProbInput = input + self.logProbInput = self.input end end @@ -39,6 +48,11 @@ function DistNLLCriterion:denormalize(input) else self.gradInput = self.gradLogInput end + + -- if input is a distance, then flip gradients back + if self.inputIsADistance then + self.gradInput:mul(-1) + end end function DistNLLCriterion:forward(input, target) |