Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-27 09:50:33 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-27 09:50:33 +0400
commit7998ade5d524c5e4938a108e3c3c863606ceaf29 (patch)
tree91a0de01c91c370330bf3b014e099b69ea2c481f /MultiSoftMax.lua
parent958e3ba2cff6450b51646bbd46878e1effe170c7 (diff)
MultiSoftMax works
Diffstat (limited to 'MultiSoftMax.lua')
-rw-r--r--MultiSoftMax.lua5
1 files changed, 4 insertions, 1 deletions
diff --git a/MultiSoftMax.lua b/MultiSoftMax.lua
index f08588f..20db4b8 100644
--- a/MultiSoftMax.lua
+++ b/MultiSoftMax.lua
@@ -4,7 +4,7 @@
------------------------------------------------------------------------
local MultiSoftMax, parent = torch.class('nn.MultiSoftMax', 'nn.Module')
-function MultiSoftMax.__init()
+function MultiSoftMax.__init(self)
parent.__init(self)
self._input = torch.Tensor()
self._output = torch.Tensor()
@@ -35,7 +35,10 @@ function MultiSoftMax:updateGradInput(input, gradOutput)
self._gradOutput:view(gradOutput, input:size(1)*input:size(2), input:size(3))
local gradInput = self.gradInput
self.gradInput = self._gradInput
+ local output = self.output
+ self.output = self._output
input.nn.SoftMax_updateGradInput(self, self._input, self._gradOutput)
self.gradInput = gradInput:viewAs(self.gradInput, input)
+ self.output = output
return self.gradInput
end