Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2017-04-19 18:04:38 +0300
committerGitHub <noreply@github.com>2017-04-19 18:04:38 +0300
commitcfbae887bbc3f54566905c446cedab21da65293a (patch)
tree6520b961a3c613d6bcbbc53b3d7050e9038c6ecb
parent0c35de7a71b9c91108f02eb042af7a7cf657e484 (diff)
parent8f77d66dd824d9a057338f2590efc2d9302cccf6 (diff)
Merge pull request #1197 from davidemaz/transposeSetNumInputDims
nn.Transpose setNumInputDims() added
-rw-r--r--Transpose.lua13
-rw-r--r--doc/simple.md13
2 files changed, 23 insertions, 3 deletions
diff --git a/Transpose.lua b/Transpose.lua
index 263db60..cceb2b6 100644
--- a/Transpose.lua
+++ b/Transpose.lua
@@ -7,11 +7,18 @@ local Transpose, parent = torch.class('nn.Transpose', 'nn.Module')
function Transpose:__init(...)
parent.__init(self)
self.permutations = {...}
+ self.numInputDims = nil
+end
+
+function Transpose:setNumInputDims(numInputDims)
+ self.numInputDims = numInputDims
+ return self
end
function Transpose:updateOutput(input)
+ local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0
for _,perm in ipairs(self.permutations) do
- input = input:transpose(perm[1],perm[2])
+ input = input:transpose(perm[1]+offset,perm[2]+offset)
end
self.output:resizeAs(input):copy(input)
return self.output
@@ -20,9 +27,9 @@ end
function Transpose:updateGradInput(input, gradOutput)
for i = #self.permutations,1,-1 do
local perm = self.permutations[i]
- gradOutput = gradOutput:transpose(perm[1],perm[2])
+ local offset = self.numInputDims and input:nDimension()-self.numInputDims or 0
+ gradOutput = gradOutput:transpose(perm[1]+offset,perm[2]+offset)
end
self.gradInput:resizeAs(gradOutput):copy(gradOutput)
return self.gradInput
end
-
diff --git a/doc/simple.md b/doc/simple.md
index 804727c..9306edf 100644
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -1244,6 +1244,19 @@ t:transpose(dim1, dim2)
t:transpose(dim3, dim4)
```
+The method `setNumInputDims()` allows to specify the expected number of dimensions of the inputs of the modules. This makes it possible to use minibatch inputs. Example:
+```lua
+b = 5 -- batch size 5
+input = torch.Tensor(b, 2, 4, 3) -- input: b x 2 x 4 x 3
+
+m = nn.Transpose({1,3})
+m:forward(input) -- output: 4 x 2 x b x 3 x 1
+
+numInputDims = 3 -- input feature map should be the last 3 dims
+m = nn.Transpose({1,3}):setNumInputDims(numInputDims)
+m:forward(input) -- output: b x 3 x 4 x 2
+```
+
<a name="nn.Exp"></a>
## Exp ##