diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-09-23 04:52:29 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2014-10-27 17:05:15 +0300 |
commit | 8d7d03ebe72ae507ae716f292778123dc34e04b1 (patch) | |
tree | b6f2516e6855793d7e06484f0a443948320d8551 /DepthConcat.lua | |
parent | c721165632345794bb3f5faf6a3502d830c207b6 (diff) |
DepthConcat
Diffstat (limited to 'DepthConcat.lua')
-rw-r--r-- | DepthConcat.lua | 96 |
1 files changed, 96 insertions, 0 deletions
diff --git a/DepthConcat.lua b/DepthConcat.lua new file mode 100644 index 0000000..70646f4 --- /dev/null +++ b/DepthConcat.lua @@ -0,0 +1,96 @@ +------------------------------------------------------------------------ +--[[ DepthConcat ]]-- +-- Concatenates the output of Convolutions along the depth dimension +-- (nOutputFrame). This is used to implement the DepthConcat layer +-- of the Going deeper with convolutions paper : +-- http://arxiv.org/pdf/1409.4842v1.pdf +-- The normal Concat Module can't be used since the spatial dimensions +-- of tensors to be concatenated may have different values. To deal with +-- this, we select the largest spatial dimensions and add zero-padding +-- around the smaller dimensions. +------------------------------------------------------------------------ +local DepthConcat, parent = torch.class('nn.DepthConcat', 'nn.Concat') + +function DepthConcat:windowNarrow(output, currentOutput, offset) + local outputWindow = output:narrow(self.dimension, offset, currentOutput:size(self.dimension)) + for dim=1,self.size:size(1) do + local currentSize = currentOutput:size(dim) + if dim ~= self.dimension and self.size[dim] ~= currentSize then + -- 5x5 vs 3x3 -> start = [(5-3)/2] + 1 = 2 (1 pad each side) + -- 9x9 vs 5x5 -> start = [(9-5)/2] + 1 = 3 (2 pad each side) + -- 9x9 vs 4x4 -> start = [(9-4)/2] + 1 = 3.5 (2 pad, 3 pad) + local start = math.floor(((self.size[dim] - currentSize) / 2) + 1) + outputWindow = outputWindow:narrow(dim, start, currentSize) + end + end + return outputWindow +end + +function DepthConcat:updateOutput(input) + local outs = {} + for i=1,#self.modules do + local currentOutput = self.modules[i]:updateOutput(input) + outs[i] = currentOutput + if i == 1 then + self.size:resize(currentOutput:dim()):copy(currentOutput:size()) + else + self.size[self.dimension] = self.size[self.dimension] + currentOutput:size(self.dimension) + for dim=1,self.size:size(1) do + if dim ~= self.dimension then + -- take the maximum size (shouldn't change anything for batch dim) + self.size[dim] = math.max(self.size[dim], currentOutput:size(dim)) + end + end + end + end + self.output:resize(self.size):zero() --zero for padding + + local offset = 1 + for i,module in ipairs(self.modules) do + local currentOutput = outs[i] + local outputWindow = self:windowNarrow(self.output, currentOutput, offset) + outputWindow:copy(currentOutput) + offset = offset + currentOutput:size(self.dimension) + end + return self.output +end + +function DepthConcat:updateGradInput(input, gradOutput) + self.gradInput:resizeAs(input) + + local offset = 1 + for i,module in ipairs(self.modules) do + local currentOutput = module.output + local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset) + local currentGradInput = module:updateGradInput(input, gradOutputWindow) + if i==1 then + self.gradInput:copy(currentGradInput) + else + self.gradInput:add(currentGradInput) + end + offset = offset + currentOutput:size(self.dimension) + end + return self.gradInput +end + +function DepthConcat:accGradParameters(input, gradOutput, scale) + scale = scale or 1 + local offset = 1 + for i,module in ipairs(self.modules) do + local currentOutput = module.output + local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset) + local currentGradInput = module:accGradParameters(input, gradOutputWindow, scale) + offset = offset + currentOutput:size(self.dimension) + end +end + +function DepthConcat:accUpdateGradParameters(input, gradOutput, lr) + local offset = 1 + for i,module in ipairs(self.modules) do + local currentOutput = module.output + local gradOutputWindow = self:windowNarrow(gradOutput, currentOutput, offset) + local currentGradInput = module:accUpdateGradParameters(input, gradOutputWindow, lr) + offset = offset + currentOutput:size(self.dimension) + end +end + |