diff options
author | davidemaz <davidemaz@gmail.com> | 2017-04-19 15:11:29 +0300 |
---|---|---|
committer | davidemaz <davidemaz@gmail.com> | 2017-04-19 15:11:29 +0300 |
commit | aafe611dff1447915d0d1075a41a2b9c52f1763c (patch) | |
tree | ed29c6114645171aaafad8e8fe9893f68f5bb58e /Transpose.lua | |
parent | 0c35de7a71b9c91108f02eb042af7a7cf657e484 (diff) |
Added setNumInputDims to nn.Transpose
Diffstat (limited to 'Transpose.lua')
-rw-r--r-- | Transpose.lua | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/Transpose.lua b/Transpose.lua index 263db60..cceb2b6 100644 --- a/Transpose.lua +++ b/Transpose.lua @@ -7,11 +7,18 @@ local Transpose, parent = torch.class('nn.Transpose', 'nn.Module') function Transpose:__init(...) parent.__init(self) self.permutations = {...} + self.numInputDims = nil +end + +function Transpose:setNumInputDims(numInputDims) + self.numInputDims = numInputDims + return self end function Transpose:updateOutput(input) + local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0 for _,perm in ipairs(self.permutations) do - input = input:transpose(perm[1],perm[2]) + input = input:transpose(perm[1]+offset,perm[2]+offset) end self.output:resizeAs(input):copy(input) return self.output @@ -20,9 +27,9 @@ end function Transpose:updateGradInput(input, gradOutput) for i = #self.permutations,1,-1 do local perm = self.permutations[i] - gradOutput = gradOutput:transpose(perm[1],perm[2]) + local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0 + gradOutput = gradOutput:transpose(perm[1]+offset,perm[2]+offset) end self.gradInput:resizeAs(gradOutput):copy(gradOutput) return self.gradInput end - |