Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndreas Köpf <andreas.koepf@xamla.com>2016-01-26 01:23:02 +0300
committerAndreas Köpf <andreas.koepf@xamla.com>2016-02-01 21:54:07 +0300
commit68f61cf984f582ed3d4ece5d9e9073f19e57345e (patch)
treeaaf84f4efaf3ee1717fff2b3ad1a35fadb19dedf /SpatialSoftMax.lua
parentab95570bc4a26a515c30d80b37dcd68af102cb9f (diff)
Add THNN conversion of {RReLU, Sigmoid, SmoothL1Criterion,SoftMax, SoftPlus}
Diffstat (limited to 'SpatialSoftMax.lua')
-rw-r--r--SpatialSoftMax.lua14
1 files changed, 12 insertions, 2 deletions
diff --git a/SpatialSoftMax.lua b/SpatialSoftMax.lua
index bbf6933..56f0b40 100644
--- a/SpatialSoftMax.lua
+++ b/SpatialSoftMax.lua
@@ -1,9 +1,19 @@
local SpatialSoftMax, _ = torch.class('nn.SpatialSoftMax', 'nn.Module')
function SpatialSoftMax:updateOutput(input)
- return input.nn.SoftMax_updateOutput(self, input)
+ input.THNN.SoftMax_updateOutput(
+ input:cdata(),
+ self.output:cdata()
+ )
+ return self.output
end
function SpatialSoftMax:updateGradInput(input, gradOutput)
- return input.nn.SoftMax_updateGradInput(self, input, gradOutput)
+ input.THNN.SoftMax_updateGradInput(
+ input:cdata(),
+ gradOutput:cdata(),
+ self.gradInput:cdata(),
+ self.output:cdata()
+ )
+ return self.gradInput
end