Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominik Grewe <dominikg@google.com>2016-02-05 21:23:58 +0300
committerDominik Grewe <dominikg@google.com>2016-02-05 21:23:58 +0300
commitc77e51ebc6240a21b8427d44670b8e36794b2d3b (patch)
treea5698aa9b59df40b61ed3a1fad8a339652fa2662 /MultiSoftMax.lua
parent5bb2bcbcfbbe65ea33ea4487f631da1fae071de2 (diff)
Use THNN.
Diffstat (limited to 'MultiSoftMax.lua')
-rw-r--r--MultiSoftMax.lua10
1 files changed, 6 insertions, 4 deletions
diff --git a/MultiSoftMax.lua b/MultiSoftMax.lua
index 20db4b8..9eb768a 100644
--- a/MultiSoftMax.lua
+++ b/MultiSoftMax.lua
@@ -14,7 +14,7 @@ end
function MultiSoftMax:updateOutput(input)
if input:dim() == 2 then
- return input.nn.SoftMax_updateOutput(self, input)
+ return input.THNN.SoftMax_updateOutput(input:cdata(), self.output:cdata())
end
if input:dim() ~= 3 then
error"Only supports 2D or 3D inputs"
@@ -22,7 +22,7 @@ function MultiSoftMax:updateOutput(input)
self._input:view(input, input:size(1)*input:size(2), input:size(3))
local output = self.output
self.output = self._output
- input.nn.SoftMax_updateOutput(self, self._input)
+ input.THNN.SoftMax_updateOutput(self._input:cdata(), self.output:cdata())
output:viewAs(self.output, input)
self.output = output
return self.output
@@ -30,14 +30,16 @@ end
function MultiSoftMax:updateGradInput(input, gradOutput)
if input:dim() == 2 then
- return input.nn.SoftMax_updateGradInput(self, input, gradOutput)
+ return input.THNN.SoftMax_updateGradInput(input:cdata(), gradOutput:cdata(),
+ self.gradInput:cdata(), self.output:cdata())
end
self._gradOutput:view(gradOutput, input:size(1)*input:size(2), input:size(3))
local gradInput = self.gradInput
self.gradInput = self._gradInput
local output = self.output
self.output = self._output
- input.nn.SoftMax_updateGradInput(self, self._input, self._gradOutput)
+ input.THNN.SoftMax_updateGradInput(self._input:cdata(), self._gradOutput:cdata(),
+ self.gradInput:cdata(), self.output:cdata())
self.gradInput = gradInput:viewAs(self.gradInput, input)
self.output = output
return self.gradInput