Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordavidemaz <davidemaz@gmail.com>2017-04-19 15:11:29 +0300
committerdavidemaz <davidemaz@gmail.com>2017-04-19 15:11:29 +0300
commitaafe611dff1447915d0d1075a41a2b9c52f1763c (patch)
treeed29c6114645171aaafad8e8fe9893f68f5bb58e /Transpose.lua
parent0c35de7a71b9c91108f02eb042af7a7cf657e484 (diff)
Added setNumInputDims to nn.Transpose
Diffstat (limited to 'Transpose.lua')
-rw-r--r--Transpose.lua13
1 files changed, 10 insertions, 3 deletions
diff --git a/Transpose.lua b/Transpose.lua
index 263db60..cceb2b6 100644
--- a/Transpose.lua
+++ b/Transpose.lua
@@ -7,11 +7,18 @@ local Transpose, parent = torch.class('nn.Transpose', 'nn.Module')
function Transpose:__init(...)
parent.__init(self)
self.permutations = {...}
+ self.numInputDims = nil
+end
+
+function Transpose:setNumInputDims(numInputDims)
+ self.numInputDims = numInputDims
+ return self
end
function Transpose:updateOutput(input)
+ local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0
for _,perm in ipairs(self.permutations) do
- input = input:transpose(perm[1],perm[2])
+ input = input:transpose(perm[1]+offset,perm[2]+offset)
end
self.output:resizeAs(input):copy(input)
return self.output
@@ -20,9 +27,9 @@ end
function Transpose:updateGradInput(input, gradOutput)
for i = #self.permutations,1,-1 do
local perm = self.permutations[i]
- gradOutput = gradOutput:transpose(perm[1],perm[2])
+ local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0
+ gradOutput = gradOutput:transpose(perm[1]+offset,perm[2]+offset)
end
self.gradInput:resizeAs(gradOutput):copy(gradOutput)
return self.gradInput
end
-