diff options
author | Dominik Grewe <dominikg@google.com> | 2016-02-05 21:23:58 +0300 |
---|---|---|
committer | Dominik Grewe <dominikg@google.com> | 2016-02-05 21:23:58 +0300 |
commit | c77e51ebc6240a21b8427d44670b8e36794b2d3b (patch) | |
tree | a5698aa9b59df40b61ed3a1fad8a339652fa2662 /MultiSoftMax.lua | |
parent | 5bb2bcbcfbbe65ea33ea4487f631da1fae071de2 (diff) |
Use THNN.
Diffstat (limited to 'MultiSoftMax.lua')
-rw-r--r-- | MultiSoftMax.lua | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/MultiSoftMax.lua b/MultiSoftMax.lua index 20db4b8..9eb768a 100644 --- a/MultiSoftMax.lua +++ b/MultiSoftMax.lua @@ -14,7 +14,7 @@ end function MultiSoftMax:updateOutput(input) if input:dim() == 2 then - return input.nn.SoftMax_updateOutput(self, input) + return input.THNN.SoftMax_updateOutput(input:cdata(), self.output:cdata()) end if input:dim() ~= 3 then error"Only supports 2D or 3D inputs" @@ -22,7 +22,7 @@ function MultiSoftMax:updateOutput(input) self._input:view(input, input:size(1)*input:size(2), input:size(3)) local output = self.output self.output = self._output - input.nn.SoftMax_updateOutput(self, self._input) + input.THNN.SoftMax_updateOutput(self._input:cdata(), self.output:cdata()) output:viewAs(self.output, input) self.output = output return self.output @@ -30,14 +30,16 @@ end function MultiSoftMax:updateGradInput(input, gradOutput) if input:dim() == 2 then - return input.nn.SoftMax_updateGradInput(self, input, gradOutput) + return input.THNN.SoftMax_updateGradInput(input:cdata(), gradOutput:cdata(), + self.gradInput:cdata(), self.output:cdata()) end self._gradOutput:view(gradOutput, input:size(1)*input:size(2), input:size(3)) local gradInput = self.gradInput self.gradInput = self._gradInput local output = self.output self.output = self._output - input.nn.SoftMax_updateGradInput(self, self._input, self._gradOutput) + input.THNN.SoftMax_updateGradInput(self._input:cdata(), self._gradOutput:cdata(), + self.gradInput:cdata(), self.output:cdata()) self.gradInput = gradInput:viewAs(self.gradInput, input) self.output = output return self.gradInput |