github.com/torch/nn.git
author    nicholas-leonard <nick@nikopia.org>    2014-09-23 04:52:29 +0400
committer Soumith Chintala <soumith@gmail.com>   2014-10-27 17:05:15 +0300
commit    8d7d03ebe72ae507ae716f292778123dc34e04b1 (patch)
tree      b6f2516e6855793d7e06484f0a443948320d8551 /DepthConcat.lua
parent    c721165632345794bb3f5faf6a3502d830c207b6 (diff)
DepthConcat
Diffstat (limited to 'DepthConcat.lua')
-rw-r--r--  DepthConcat.lua | 96
1 file changed, 96 insertions(+), 0 deletions(-)
diff --git a/DepthConcat.lua b/DepthConcat.lua
new file mode 100644
index 0000000..70646f4
--- /dev/null
+++ b/DepthConcat.lua
@@ -0,0 +1,96 @@
+------------------------------------------------------------------------
+--[[ DepthConcat ]]--
+-- Concatenates the output of convolution modules along the depth
+-- dimension (nOutputFrame). This is used to implement the DepthConcat
+-- layer of the "Going deeper with convolutions" paper:
+-- http://arxiv.org/pdf/1409.4842v1.pdf
+-- The regular Concat module cannot be used here, since the spatial
+-- dimensions of the tensors to be concatenated may differ. To deal
+-- with this, we select the largest spatial dimensions and zero-pad
+-- the smaller outputs to match them.
+------------------------------------------------------------------------
+local DepthConcat, parent = torch.class('nn.DepthConcat', 'nn.Concat')
+
+function DepthConcat:windowNarrow(output, currentOutput, offset)
+   local outputWindow = output:narrow(self.dimension, offset, currentOutput:size(self.dimension))
+   for dim=1,self.size:size(1) do
+      local currentSize = currentOutput:size(dim)
+      if dim ~= self.dimension and self.size[dim] ~= currentSize then
+         -- 5x5 vs 3x3 -> start = [(5-3)/2] + 1 = 2 (1 pad each side)
+         -- 9x9 vs 5x5 -> start = [(9-5)/2] + 1 = 3 (2 pad each side)
+         -- 9x9 vs 4x4 -> start = floor[(9-4)/2 + 1] = 3 (2 pad, 3 pad)
+         local start = math.floor(((self.size[dim] - currentSize) / 2) + 1)
+         outputWindow = outputWindow:narrow(dim, start, currentSize)
+      end
+   end
+   return outputWindow
+end
+
+function DepthConcat:updateOutput(input)
+   local outs = {}
+   for i=1,#self.modules do
+      local currentOutput = self.modules[i]:updateOutput(input)
+      outs[i] = currentOutput
+      if i == 1 then
+         self.size:resize(currentOutput:dim()):copy(currentOutput:size())
+      else
+         self.size[self.dimension] = self.size[self.dimension] + currentOutput:size(self.dimension)
+         for dim=1,self.size:size(1) do
+            if dim ~= self.dimension then
+               -- take the maximum size (shouldn't change anything for batch dim)
+               self.size[dim] = math.max(self.size[dim], currentOutput:size(dim))
+            end
+         end
+      end
+   end
+   self.output:resize(self.size):zero() -- zero for padding
+
+   local offset = 1
+   for i,module in ipairs(self.modules) do
+      local currentOutput = outs[i]
+      local outputWindow = self:windowNarrow(self.output, currentOutput, offset)
+      outputWindow:copy(currentOutput)
+      offset = offset + currentOutput:size(self.dimension)
+   end
+   return self.output
+end
+
+function DepthConcat:updateGradInput(input, gradOutput)
+   self.gradInput:resizeAs(input)
+
+   local offset = 1
+   for i,module in ipairs(self.modules) do
+      local currentOutput = module.output
+      local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset)
+      local currentGradInput = module:updateGradInput(input, gradOutputWindow)
+      if i==1 then
+         self.gradInput:copy(currentGradInput)
+      else
+         self.gradInput:add(currentGradInput)
+      end
+      offset = offset + currentOutput:size(self.dimension)
+   end
+   return self.gradInput
+end
+
+function DepthConcat:accGradParameters(input, gradOutput, scale)
+   scale = scale or 1
+   local offset = 1
+   for i,module in ipairs(self.modules) do
+      local currentOutput = module.output
+      local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset)
+      module:accGradParameters(input, gradOutputWindow, scale)
+      offset = offset + currentOutput:size(self.dimension)
+   end
+end
+
+function DepthConcat:accUpdateGradParameters(input, gradOutput, lr)
+   local offset = 1
+   for i,module in ipairs(self.modules) do
+      local currentOutput = module.output
+      local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset)
+      module:accUpdateGradParameters(input, gradOutputWindow, lr)
+      offset = offset + currentOutput:size(self.dimension)
+   end
+end
+
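
A minimal usage sketch for context (not part of this patch; the module choices and sizes below are illustrative, assuming stride-1 SpatialConvolution modules with no padding on a non-batched 3x7x7 input): DepthConcat holds convolutions with different kernel sizes, takes the largest spatial output (7x7 here) and zero-pads the smaller ones before concatenating along dimension 1, the depth dimension.

require 'nn'

local inputSize, outputSize = 3, 2
local input = torch.randn(inputSize, 7, 7)   -- non-batched input: depth is dimension 1

local mlp = nn.DepthConcat(1)
mlp:add(nn.SpatialConvolution(inputSize, outputSize, 1, 1))  -- 2x7x7 output
mlp:add(nn.SpatialConvolution(inputSize, outputSize, 3, 3))  -- 2x5x5 output, zero-padded to 7x7
mlp:add(nn.SpatialConvolution(inputSize, outputSize, 4, 4))  -- 2x4x4 output, zero-padded to 7x7

print(mlp:forward(input):size())  -- 6x7x7: depths add up, spatial dims padded to the maximum

Note that the 2x4x4 output cannot be centered exactly in a 7x7 window; this is the case the math.floor in windowNarrow handles, leaving one row and column of zero padding on one side and two on the other.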