diff options
Diffstat (limited to 'DistNLLCriterion.lua')
-rw-r--r-- | DistNLLCriterion.lua | 30 |
1 files changed, 5 insertions, 25 deletions
diff --git a/DistNLLCriterion.lua b/DistNLLCriterion.lua index 01290c4..22204fc 100644 --- a/DistNLLCriterion.lua +++ b/DistNLLCriterion.lua @@ -17,7 +17,7 @@ end function DistNLLCriterion:normalize(input, target) -- normalize target if not self.targetIsProbability then - self.probTarget = self.targetSoftMax:forward(target) + self.probTarget = self.targetSoftMax:updateOutput(target) else self.probTarget = target end @@ -31,7 +31,7 @@ function DistNLLCriterion:normalize(input, target) -- normalize input if not self.inputIsLogProbability and not self.inputIsProbability then - self.logProbInput = self.inputLogSoftMax:forward(self.input) + self.logProbInput = self.inputLogSoftMax:updateOutput(self.input) elseif not self.inputIsLogProbability then print('TODO: implement nn.Log()') else @@ -42,7 +42,7 @@ end function DistNLLCriterion:denormalize() -- denormalize gradients if not self.inputIsLogProbability and not self.inputIsProbability then - self.gradInput = self.inputLogSoftMax:backward(self.input, self.gradLogInput) + self.gradInput = self.inputLogSoftMax:updateGradInput(self.input, self.gradLogInput) elseif not self.inputIsLogProbability then print('TODO: implement nn.Log()') else @@ -55,7 +55,7 @@ function DistNLLCriterion:denormalize() end end -function DistNLLCriterion:forward(input, target) +function DistNLLCriterion:updateOutput(input, target) self:normalize(input, target) self.output = 0 for i = 1,input:size(1) do @@ -64,7 +64,7 @@ function DistNLLCriterion:forward(input, target) return self.output end -function DistNLLCriterion:backward(input, target) +function DistNLLCriterion:updateGradInput(input, target) self:normalize(input, target) self.gradLogInput:resizeAs(input) for i = 1,input:size(1) do @@ -73,23 +73,3 @@ function DistNLLCriterion:backward(input, target) self:denormalize() return self.gradInput end - -function DistNLLCriterion:write(file) - parent.write(self, file) - file:writeBool(self.inputIsProbability) - file:writeBool(self.inputIsLogProbability) - file:writeBool(self.targetIsProbability) - file:writeObject(self.targetSoftMax) - file:writeObject(self.inputLogSoftMax) - file:writeObject(self.gradLogInput) -end - -function DistNLLCriterion:read(file) - parent.read(self, file) - self.inputIsProbability = file:readBool() - self.inputIsLogProbability = file:readBool() - self.targetIsProbability = file:readBool() - self.targetSoftMax = file:readObject() - self.inputLogSoftMax = file:readObject() - self.gradLogInput = file:readObject() -end |