github.com/clementfarabet/lua---nnx.git
Diffstat (limited to 'DistNLLCriterion.lua')
-rw-r--r--  DistNLLCriterion.lua  |  30
1 file changed, 5 insertions(+), 25 deletions(-)
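
The diff below switches DistNLLCriterion from overriding forward/backward to implementing updateOutput/updateGradInput, the method names the nn API dispatches to; the internal SoftMax/LogSoftMax calls are renamed the same way, and the hand-written write/read serializers are dropped. Calling code is unaffected, since nn.Criterion keeps forward and backward as thin wrappers around the new names. A minimal usage sketch, not part of the commit: it assumes nnx registers the class as nn.DistNLLCriterion with default flags, so both input and target are raw scores normalized internally.

    require 'nnx'

    local crit = nn.DistNLLCriterion()
    local input = torch.randn(10)   -- raw scores; normalize() pushes them through LogSoftMax
    local target = torch.randn(10)  -- raw target scores; normalize() pushes them through SoftMax

    -- nn.Criterion's forward/backward wrappers dispatch to the renamed methods
    local loss = crit:forward(input, target)    -- ends up in updateOutput
    local grad = crit:backward(input, target)   -- ends up in updateGradInput
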
diff --git a/DistNLLCriterion.lua b/DistNLLCriterion.lua
index 01290c4..22204fc 100644
--- a/DistNLLCriterion.lua
+++ b/DistNLLCriterion.lua
@@ -17,7 +17,7 @@ end
 function DistNLLCriterion:normalize(input, target)
    -- normalize target
    if not self.targetIsProbability then
-      self.probTarget = self.targetSoftMax:forward(target)
+      self.probTarget = self.targetSoftMax:updateOutput(target)
    else
       self.probTarget = target
    end
@@ -31,7 +31,7 @@ function DistNLLCriterion:normalize(input, target)

    -- normalize input
    if not self.inputIsLogProbability and not self.inputIsProbability then
-      self.logProbInput = self.inputLogSoftMax:forward(self.input)
+      self.logProbInput = self.inputLogSoftMax:updateOutput(self.input)
    elseif not self.inputIsLogProbability then
       print('TODO: implement nn.Log()')
    else
@@ -42,7 +42,7 @@ end
 function DistNLLCriterion:denormalize()
    -- denormalize gradients
    if not self.inputIsLogProbability and not self.inputIsProbability then
-      self.gradInput = self.inputLogSoftMax:backward(self.input, self.gradLogInput)
+      self.gradInput = self.inputLogSoftMax:updateGradInput(self.input, self.gradLogInput)
    elseif not self.inputIsLogProbability then
       print('TODO: implement nn.Log()')
    else
@@ -55,7 +55,7 @@ function DistNLLCriterion:denormalize()
    end
 end

-function DistNLLCriterion:forward(input, target)
+function DistNLLCriterion:updateOutput(input, target)
    self:normalize(input, target)
    self.output = 0
    for i = 1,input:size(1) do
@@ -64,7 +64,7 @@ function DistNLLCriterion:forward(input, target)
    return self.output
 end

-function DistNLLCriterion:backward(input, target)
+function DistNLLCriterion:updateGradInput(input, target)
    self:normalize(input, target)
    self.gradLogInput:resizeAs(input)
    for i = 1,input:size(1) do
@@ -73,23 +73,3 @@ function DistNLLCriterion:backward(input, target)
    self:denormalize()
    return self.gradInput
 end
-
-function DistNLLCriterion:write(file)
-   parent.write(self, file)
-   file:writeBool(self.inputIsProbability)
-   file:writeBool(self.inputIsLogProbability)
-   file:writeBool(self.targetIsProbability)
-   file:writeObject(self.targetSoftMax)
-   file:writeObject(self.inputLogSoftMax)
-   file:writeObject(self.gradLogInput)
-end
-
-function DistNLLCriterion:read(file)
-   parent.read(self, file)
-   self.inputIsProbability = file:readBool()
-   self.inputIsLogProbability = file:readBool()
-   self.targetIsProbability = file:readBool()
-   self.targetSoftMax = file:readObject()
-   self.inputLogSoftMax = file:readObject()
-   self.gradLogInput = file:readObject()
-end
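
The deleted write/read pair spelled out, field by field, what torch's generic serializer can do on its own: writeObject/readObject walk every member of a torch class, so the flags and the SoftMax/LogSoftMax submodules are saved either way. A hedged round-trip sketch under that assumption (the file name is illustrative only):

    require 'nnx'

    local crit = nn.DistNLLCriterion()
    torch.save('crit.t7', crit)          -- generic serializer stores all fields,
                                         -- including targetSoftMax and inputLogSoftMax
    local restored = torch.load('crit.t7')
    print(restored.targetIsProbability)  -- flags survive the round trip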