diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-27 09:50:33 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-27 09:50:33 +0400 |
commit | 7998ade5d524c5e4938a108e3c3c863606ceaf29 (patch) | |
tree | 91a0de01c91c370330bf3b014e099b69ea2c481f /MultiSoftMax.lua | |
parent | 958e3ba2cff6450b51646bbd46878e1effe170c7 (diff) |
MultiSoftMax works
Diffstat (limited to 'MultiSoftMax.lua')
-rw-r--r-- | MultiSoftMax.lua | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/MultiSoftMax.lua b/MultiSoftMax.lua index f08588f..20db4b8 100644 --- a/MultiSoftMax.lua +++ b/MultiSoftMax.lua @@ -4,7 +4,7 @@ ------------------------------------------------------------------------ local MultiSoftMax, parent = torch.class('nn.MultiSoftMax', 'nn.Module') -function MultiSoftMax.__init() +function MultiSoftMax.__init(self) parent.__init(self) self._input = torch.Tensor() self._output = torch.Tensor() @@ -35,7 +35,10 @@ function MultiSoftMax:updateGradInput(input, gradOutput) self._gradOutput:view(gradOutput, input:size(1)*input:size(2), input:size(3)) local gradInput = self.gradInput self.gradInput = self._gradInput + local output = self.output + self.output = self._output input.nn.SoftMax_updateGradInput(self, self._input, self._gradOutput) self.gradInput = gradInput:viewAs(self.gradInput, input) + self.output = output return self.gradInput end |