diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-08-30 00:55:12 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-08-30 00:55:12 +0400 |
commit | 398e21cfaf6a8e4fa47edd9e6de081497e75dda2 (patch) | |
tree | 6d9779b6d7fa47cf518993c2d788429369a58264 | |
parent | 0b7f3b9cd4e578a96cfc6b50b18bca2f27cc4682 (diff) |
Added DistNLLCriterion, to support Neg Likelihood for distributions.
ClassNLLCriterion only supports simple distributions, e.g. one-of-N.
DistNLLCriterion supports arbitrary distributions.
-rw-r--r-- | DataList.lua | 13 | ||||
-rw-r--r-- | DistNLLCriterion.lua | 61 | ||||
-rw-r--r-- | init.lua | 1 | ||||
-rw-r--r-- | nnx-1.0-1.rockspec | 1 |
4 files changed, 74 insertions, 2 deletions
diff --git a/DataList.lua b/DataList.lua index 9677022..4922e8b 100644 --- a/DataList.lua +++ b/DataList.lua @@ -13,6 +13,7 @@ function DataList:__init() self.nbClass = 0 self.ClassName = {} self.nbSamples = 0 + self.targetIsProbability = false self.spatialTarget = false end @@ -32,10 +33,18 @@ function DataList:__index__(key) -- create target vector on the fly if self.spatialTarget then - self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):fill(-1) + if self.targetIsProbability then + self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):zero() + else + self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):fill(-1) + end self.datasets[class][elmt][2][class][1][1] = 1 else - self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):fill(-1) + if self.targetIsProbability then + self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):zero() + else + self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):fill(-1) + end self.datasets[class][elmt][2][class] = 1 end diff --git a/DistNLLCriterion.lua b/DistNLLCriterion.lua new file mode 100644 index 0000000..c0b69e3 --- /dev/null +++ b/DistNLLCriterion.lua @@ -0,0 +1,61 @@ +local DistNLLCriterion, parent = torch.class('nn.DistNLLCriterion', 'nn.Criterion') + +function DistNLLCriterion:__init() + parent.__init(self) + -- user options + self.inputIsProbability = false + self.inputIsLogProbability = false + self.targetIsProbability = false + -- internal + self.targetSoftMax = nn.SoftMax() + self.inputLogSoftMax = nn.LogSoftMax() + self.gradLogInput = torch.Tensor() +end + +function DistNLLCriterion:normalize(input, target) + -- normalize target + if not self.targetIsProbability then + self.probTarget = self.targetSoftMax:forward(target) + else + self.probTarget = target + end + + -- normalize input + if not self.inputIsLogProbability and not self.inputIsProbability then + self.logProbInput = self.inputLogSoftMax:forward(input) + elseif not self.inputIsLogProbability then + print('TODO: implement nn.Log()') + else + self.logProbInput = input + end +end + +function DistNLLCriterion:denormalize(input) + -- denormalize gradients + if not self.inputIsLogProbability and not self.inputIsProbability then + self.gradInput = self.inputLogSoftMax:backward(input, self.gradLogInput) + elseif not self.inputIsLogProbability then + print('TODO: implement nn.Log()') + else + self.gradInput = self.gradLogInput + end +end + +function DistNLLCriterion:forward(input, target) + self:normalize(input, target) + self.output = 0 + for i = 1,input:size(1) do + self.output = self.output - self.logProbInput[i] * self.probTarget[i] + end + return self.output +end + +function DistNLLCriterion:backward(input, target) + self:normalize(input, target) + self.gradLogInput:resizeAs(input) + for i = 1,input:size(1) do + self.gradLogInput[i] = -self.probTarget[i] + end + self:denormalize(input) + return self.gradInput +end @@ -93,6 +93,7 @@ torch.include('nnx', 'SpatialColorTransform.lua') -- criterions: torch.include('nnx', 'SuperCriterion.lua') torch.include('nnx', 'SparseCriterion.lua') +torch.include('nnx', 'DistNLLCriterion.lua') torch.include('nnx', 'SpatialMSECriterion.lua') torch.include('nnx', 'SpatialClassNLLCriterion.lua') torch.include('nnx', 'SpatialSparseCriterion.lua') diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec index cfbc571..3af08d0 100644 --- a/nnx-1.0-1.rockspec +++ b/nnx-1.0-1.rockspec @@ -62,6 +62,7 @@ build = { install_files(/lua/nnx init.lua) install_files(/lua/nnx Abs.lua) install_files(/lua/nnx ConfusionMatrix.lua) + install_files(/lua/nnx DistNLLCriterion.lua) install_files(/lua/nnx Logger.lua) install_files(/lua/nnx Probe.lua) install_files(/lua/nnx HardShrink.lua) |