Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-08-30 00:55:12 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-08-30 00:55:12 +0400
commit398e21cfaf6a8e4fa47edd9e6de081497e75dda2 (patch)
tree6d9779b6d7fa47cf518993c2d788429369a58264
parent0b7f3b9cd4e578a96cfc6b50b18bca2f27cc4682 (diff)
Added DistNLLCriterion, to support Neg Likelihood for distributions.
ClassNLLCriterion only supports simple distributions, e.g. one-of-N. DistNLLCriterion supports arbitrary distributions.
-rw-r--r--DataList.lua13
-rw-r--r--DistNLLCriterion.lua61
-rw-r--r--init.lua1
-rw-r--r--nnx-1.0-1.rockspec1
4 files changed, 74 insertions, 2 deletions
diff --git a/DataList.lua b/DataList.lua
index 9677022..4922e8b 100644
--- a/DataList.lua
+++ b/DataList.lua
@@ -13,6 +13,7 @@ function DataList:__init()
self.nbClass = 0
self.ClassName = {}
self.nbSamples = 0
+ self.targetIsProbability = false
self.spatialTarget = false
end
@@ -32,10 +33,18 @@ function DataList:__index__(key)
-- create target vector on the fly
if self.spatialTarget then
- self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):fill(-1)
+ if self.targetIsProbability then
+ self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):zero()
+ else
+ self.datasets[class][elmt][2] = torch.Tensor(self.nbClass,1,1):fill(-1)
+ end
self.datasets[class][elmt][2][class][1][1] = 1
else
- self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):fill(-1)
+ if self.targetIsProbability then
+ self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):zero()
+ else
+ self.datasets[class][elmt][2] = torch.Tensor(self.nbClass):fill(-1)
+ end
self.datasets[class][elmt][2][class] = 1
end
diff --git a/DistNLLCriterion.lua b/DistNLLCriterion.lua
new file mode 100644
index 0000000..c0b69e3
--- /dev/null
+++ b/DistNLLCriterion.lua
@@ -0,0 +1,61 @@
+local DistNLLCriterion, parent = torch.class('nn.DistNLLCriterion', 'nn.Criterion')
+
+function DistNLLCriterion:__init()
+ parent.__init(self)
+ -- user options
+ self.inputIsProbability = false
+ self.inputIsLogProbability = false
+ self.targetIsProbability = false
+ -- internal
+ self.targetSoftMax = nn.SoftMax()
+ self.inputLogSoftMax = nn.LogSoftMax()
+ self.gradLogInput = torch.Tensor()
+end
+
+function DistNLLCriterion:normalize(input, target)
+ -- normalize target
+ if not self.targetIsProbability then
+ self.probTarget = self.targetSoftMax:forward(target)
+ else
+ self.probTarget = target
+ end
+
+ -- normalize input
+ if not self.inputIsLogProbability and not self.inputIsProbability then
+ self.logProbInput = self.inputLogSoftMax:forward(input)
+ elseif not self.inputIsLogProbability then
+ print('TODO: implement nn.Log()')
+ else
+ self.logProbInput = input
+ end
+end
+
+function DistNLLCriterion:denormalize(input)
+ -- denormalize gradients
+ if not self.inputIsLogProbability and not self.inputIsProbability then
+ self.gradInput = self.inputLogSoftMax:backward(input, self.gradLogInput)
+ elseif not self.inputIsLogProbability then
+ print('TODO: implement nn.Log()')
+ else
+ self.gradInput = self.gradLogInput
+ end
+end
+
+function DistNLLCriterion:forward(input, target)
+ self:normalize(input, target)
+ self.output = 0
+ for i = 1,input:size(1) do
+ self.output = self.output - self.logProbInput[i] * self.probTarget[i]
+ end
+ return self.output
+end
+
+function DistNLLCriterion:backward(input, target)
+ self:normalize(input, target)
+ self.gradLogInput:resizeAs(input)
+ for i = 1,input:size(1) do
+ self.gradLogInput[i] = -self.probTarget[i]
+ end
+ self:denormalize(input)
+ return self.gradInput
+end
diff --git a/init.lua b/init.lua
index 20246bc..ea44de0 100644
--- a/init.lua
+++ b/init.lua
@@ -93,6 +93,7 @@ torch.include('nnx', 'SpatialColorTransform.lua')
-- criterions:
torch.include('nnx', 'SuperCriterion.lua')
torch.include('nnx', 'SparseCriterion.lua')
+torch.include('nnx', 'DistNLLCriterion.lua')
torch.include('nnx', 'SpatialMSECriterion.lua')
torch.include('nnx', 'SpatialClassNLLCriterion.lua')
torch.include('nnx', 'SpatialSparseCriterion.lua')
diff --git a/nnx-1.0-1.rockspec b/nnx-1.0-1.rockspec
index cfbc571..3af08d0 100644
--- a/nnx-1.0-1.rockspec
+++ b/nnx-1.0-1.rockspec
@@ -62,6 +62,7 @@ build = {
install_files(/lua/nnx init.lua)
install_files(/lua/nnx Abs.lua)
install_files(/lua/nnx ConfusionMatrix.lua)
+ install_files(/lua/nnx DistNLLCriterion.lua)
install_files(/lua/nnx Logger.lua)
install_files(/lua/nnx Probe.lua)
install_files(/lua/nnx HardShrink.lua)