diff options
author | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-25 22:23:17 +0400 |
---|---|---|
committer | Sergio Gomez <sergomezcol@gmail.com> | 2014-06-25 22:29:28 +0400 |
commit | d85c2ce7ab24a699855f41b1919a74f81def47cd (patch) | |
tree | 27fcdcfcf13c63b06610943e82a8798ef78d260c /SplitTable.lua | |
parent | ea9cc1df751ddb144c08a13aab3add1ab0ce90a1 (diff) |
Add minibatch support for nn.JoinTable and nn.SplitTable
The method setNumInputDims allows forwarding both minibatch and
non-minibatch tensors through the same module.
If this method is not used, the behaviour of these modules is the
same as before.
Diffstat (limited to 'SplitTable.lua')
-rw-r--r-- | SplitTable.lua | 29 |
1 files changed, 22 insertions, 7 deletions
diff --git a/SplitTable.lua b/SplitTable.lua index d2c690e..b69e9ee 100644 --- a/SplitTable.lua +++ b/SplitTable.lua @@ -2,29 +2,44 @@ local SplitTable, parent = torch.class('nn.SplitTable', 'nn.Module') function SplitTable:__init(dimension) parent.__init(self) - self.modules = {} + self.modules = {} self.dimension = dimension + self.nInputDims = nil +end + +-- Sets the expected number of dimensions +-- in a non-minibatch input. +function SplitTable:setNumInputDims(nInputDims) + self.nInputDims = nInputDims + return self end function SplitTable:updateOutput(input) - local currentOutput= {}; - local slices = input:size(self.dimension) + local dimension = self.dimension + if self.nInputDims and input:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + local currentOutput= {} + local slices = input:size(dimension) for i=1,slices do - currentOutput[#currentOutput+1] = input:select(self.dimension,i) + currentOutput[#currentOutput+1] = input:select(dimension,i) end self.output = currentOutput return self.output end - function SplitTable:updateGradInput(input, gradOutput) - local slices = input:size(self.dimension) + local dimension = self.dimension + if self.nInputDims and input:dim()==(self.nInputDims+1) then + dimension = dimension + 1 + end + local slices = input:size(dimension) self.gradInput:resizeAs(input) local offset = 1 for i=1,slices do local currentGradInput = gradOutput[i]; - self.gradInput:select(self.dimension,i):copy(currentGradInput) + self.gradInput:select(dimension,i):copy(currentGradInput) end return self.gradInput end |