github.com/torch/nn.git
author    Sergio Gomez <sergomezcol@gmail.com>  2014-06-25 22:23:17 +0400
committer Sergio Gomez <sergomezcol@gmail.com>  2014-06-25 22:29:28 +0400
commit    d85c2ce7ab24a699855f41b1919a74f81def47cd (patch)
tree      27fcdcfcf13c63b06610943e82a8798ef78d260c /SplitTable.lua
parent    ea9cc1df751ddb144c08a13aab3add1ab0ce90a1 (diff)
Add minibatch support for nn.JoinTable and nn.SplitTable
The method setNumInputDims allows forwarding both minibatch and non-minibatch tensors through the same module. If this method is not used, the behaviour of these modules is the same as before.
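As a rough illustration of the new behaviour (not part of the commit; it assumes the torch and nn packages are installed), the same SplitTable instance can now handle both a single sample and a minibatch:

require 'nn'

local m = nn.SplitTable(1)   -- split along the first non-batch dimension
m:setNumInputDims(2)         -- declare that a single (non-minibatch) sample is 2D

-- Non-minibatch input: a 3x4 tensor is split along dimension 1
-- into a table of 3 tensors, each of size 4.
local single = torch.randn(3, 4)
local outSingle = m:forward(single)

-- Minibatch input: a 5x3x4 batch (batch size 5) has one extra dimension,
-- so the module shifts the split to dimension 2 and returns a table of
-- 3 tensors, each of size 5x4.
local batch = torch.randn(5, 3, 4)
local outBatch = m:forward(batch)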
Diffstat (limited to 'SplitTable.lua')
-rw-r--r--  SplitTable.lua | 29 ++++++++++++++++++++++++-------
1 file changed, 22 insertions(+), 7 deletions(-)
diff --git a/SplitTable.lua b/SplitTable.lua
index d2c690e..b69e9ee 100644
--- a/SplitTable.lua
+++ b/SplitTable.lua
@@ -2,29 +2,44 @@ local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module')
 
 function SplitTable:__init(dimension)
    parent.__init(self)
-   self.modules = {}
+   self.modules = {}
    self.dimension = dimension
+   self.nInputDims = nil
+end
+
+-- Sets the expected number of dimensions
+-- in a non-minibatch input.
+function SplitTable:setNumInputDims(nInputDims)
+   self.nInputDims = nInputDims
+   return self
 end
 
 function SplitTable:updateOutput(input)
-   local currentOutput= {};
-   local slices = input:size(self.dimension)
+   local dimension = self.dimension
+   if self.nInputDims and input:dim()==(self.nInputDims+1) then
+      dimension = dimension + 1
+   end
+   local currentOutput= {}
+   local slices = input:size(dimension)
    for i=1,slices do
-      currentOutput[#currentOutput+1] = input:select(self.dimension,i)
+      currentOutput[#currentOutput+1] = input:select(dimension,i)
    end
    self.output = currentOutput
    return self.output
 end
-
 
 function SplitTable:updateGradInput(input, gradOutput)
-   local slices = input:size(self.dimension)
+   local dimension = self.dimension
+   if self.nInputDims and input:dim()==(self.nInputDims+1) then
+      dimension = dimension + 1
+   end
+   local slices = input:size(dimension)
    self.gradInput:resizeAs(input)
 
    local offset = 1
    for i=1,slices do
       local currentGradInput = gradOutput[i];
-      self.gradInput:select(self.dimension,i):copy(currentGradInput)
+      self.gradInput:select(dimension,i):copy(currentGradInput)
    end
    return self.gradInput
 end
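The gradient path mirrors the forward path. A hedged sketch (again assuming torch and nn are installed, with illustrative sizes) of how the changed updateGradInput reassembles the per-slice gradients once the batch dimension is detected:

require 'nn'

local m = nn.SplitTable(1):setNumInputDims(2)

-- Forward a batch of 5 samples, each 3x4; the output is a table of
-- 3 tensors of size 5x4 (split along dimension 2).
local input = torch.randn(5, 3, 4)
local output = m:forward(input)

-- The gradient w.r.t. the output is one tensor per slice; backward
-- copies each of them into the matching slice of gradInput, so the
-- result has the same 5x3x4 shape as the input.
local gradOutput = {}
for i = 1, #output do
   gradOutput[i] = torch.ones(5, 4)
end
local gradInput = m:backward(input, gradOutput)
print(gradInput:size())   -- 5x3x4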