diff options
author | Soumith Chintala <soumith@gmail.com> | 2017-04-19 18:04:38 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-04-19 18:04:38 +0300 |
commit | cfbae887bbc3f54566905c446cedab21da65293a (patch) | |
tree | 6520b961a3c613d6bcbbc53b3d7050e9038c6ecb | |
parent | 0c35de7a71b9c91108f02eb042af7a7cf657e484 (diff) | |
parent | 8f77d66dd824d9a057338f2590efc2d9302cccf6 (diff) |
Merge pull request #1197 from davidemaz/transposeSetNumInputDims
nn.Transpose setNumInputDims() added
-rw-r--r-- | Transpose.lua | 13 | ||||
-rw-r--r-- | doc/simple.md | 13 |
2 files changed, 23 insertions, 3 deletions
diff --git a/Transpose.lua b/Transpose.lua index 263db60..cceb2b6 100644 --- a/Transpose.lua +++ b/Transpose.lua @@ -7,11 +7,18 @@ local Transpose, parent = torch.class('nn.Transpose', 'nn.Module') function Transpose:__init(...) parent.__init(self) self.permutations = {...} + self.numInputDims = nil +end + +function Transpose:setNumInputDims(numInputDims) + self.numInputDims = numInputDims + return self end function Transpose:updateOutput(input) + local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0 for _,perm in ipairs(self.permutations) do - input = input:transpose(perm[1],perm[2]) + input = input:transpose(perm[1]+offset,perm[2]+offset) end self.output:resizeAs(input):copy(input) return self.output @@ -20,9 +27,9 @@ end function Transpose:updateGradInput(input, gradOutput) for i = #self.permutations,1,-1 do local perm = self.permutations[i] - gradOutput = gradOutput:transpose(perm[1],perm[2]) + local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0 + gradOutput = gradOutput:transpose(perm[1]+offset,perm[2]+offset) end self.gradInput:resizeAs(gradOutput):copy(gradOutput) return self.gradInput end - diff --git a/doc/simple.md b/doc/simple.md index 804727c..9306edf 100644 --- a/doc/simple.md +++ b/doc/simple.md @@ -1244,6 +1244,19 @@ t:transpose(dim1, dim2) t:transpose(dim3, dim4) ``` +The method `setNumInputDims()` allows to specify the expected number of dimensions of the inputs of the modules. This makes it possible to use minibatch inputs. Example: +```lua +b = 5 -- batch size 5 +input = torch.Tensor(b, 2, 4, 3) -- input: b x 2 x 4 x 3 + +m = nn.Transpose({1,3}) +m:forward(input) -- output: 4 x 2 x b x 3 x 1 + +numInputDims = 3 -- input feature map should be the last 3 dims +m = nn.Transpose({1,3}):setNumInputDims(numInputDims) +m:forward(input) -- output: b x 3 x 4 x 2 +``` + <a name="nn.Exp"></a> ## Exp ## |