diff options
author | Jonas Gehring <jgehring@fb.com> | 2016-08-17 00:14:07 +0300 |
---|---|---|
committer | Jonas Gehring <jgehring@fb.com> | 2016-08-17 01:55:25 +0300 |
commit | aa55dea2da46e71ba54c48bd525c3aae7fd0bc96 (patch) | |
tree | de062f224c55a09a242b04f6335ac27539d5b872 /DepthConcat.lua | |
parent | 8f6b12e7588e3e8e351bf1da4b609607df3af111 (diff) |
Don't overwrite self.size in Concat
size() is a method of the Container base class.
Diffstat (limited to 'DepthConcat.lua')
-rw-r--r-- | DepthConcat.lua | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/DepthConcat.lua b/DepthConcat.lua index 8ae8384..f64a90e 100644 --- a/DepthConcat.lua +++ b/DepthConcat.lua @@ -13,13 +13,13 @@ local DepthConcat, _ = torch.class('nn.DepthConcat', 'nn.Concat') function DepthConcat:windowNarrow(output, currentOutput, offset) local outputWindow = output:narrow(self.dimension, offset, currentOutput:size(self.dimension)) - for dim=1,self.size:size(1) do + for dim=1,self.outputSize:size(1) do local currentSize = currentOutput:size(dim) - if dim ~= self.dimension and self.size[dim] ~= currentSize then + if dim ~= self.dimension and self.outputSize[dim] ~= currentSize then -- 5x5 vs 3x3 -> start = [(5-3)/2] + 1 = 2 (1 pad each side) -- 9x9 vs 5x5 -> start = [(9-5)/2] + 1 = 3 (2 pad each side) -- 9x9 vs 4x4 -> start = [(9-4)/2] + 1 = 3.5 (2 pad, 3 pad) - local start = math.floor(((self.size[dim] - currentSize) / 2) + 1) + local start = math.floor(((self.outputSize[dim] - currentSize) / 2) + 1) outputWindow = outputWindow:narrow(dim, start, currentSize) end end @@ -27,23 +27,25 @@ function DepthConcat:windowNarrow(output, currentOutput, offset) end function DepthConcat:updateOutput(input) + self.outputSize = self.outputSize or torch.LongStorage() + local outs = {} for i=1,#self.modules do local currentOutput = self:rethrowErrors(self.modules[i], i, 'updateOutput', input) outs[i] = currentOutput if i == 1 then - self.size:resize(currentOutput:dim()):copy(currentOutput:size()) + self.outputSize:resize(currentOutput:dim()):copy(currentOutput:size()) else - self.size[self.dimension] = self.size[self.dimension] + currentOutput:size(self.dimension) - for dim=1,self.size:size(1) do + self.outputSize[self.dimension] = self.outputSize[self.dimension] + currentOutput:size(self.dimension) + for dim=1,self.outputSize:size(1) do if dim ~= self.dimension then -- take the maximum size (shouldn't change anything for batch dim) - self.size[dim] = math.max(self.size[dim], currentOutput:size(dim)) + self.outputSize[dim] = math.max(self.outputSize[dim], currentOutput:size(dim)) end end end end - self.output:resize(self.size):zero() --zero for padding + self.output:resize(self.outputSize):zero() --zero for padding local offset = 1 for i,module in ipairs(self.modules) do |