Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonas Gehring <jgehring@fb.com>2016-08-17 00:14:07 +0300
committerJonas Gehring <jgehring@fb.com>2016-08-17 01:55:25 +0300
commitaa55dea2da46e71ba54c48bd525c3aae7fd0bc96 (patch)
treede062f224c55a09a242b04f6335ac27539d5b872 /DepthConcat.lua
parent8f6b12e7588e3e8e351bf1da4b609607df3af111 (diff)
Don't overwrite self.size in Concat
size() is a method of the Container base class.
Diffstat (limited to 'DepthConcat.lua')
-rw-r--r--DepthConcat.lua18
1 files changed, 10 insertions, 8 deletions
diff --git a/DepthConcat.lua b/DepthConcat.lua
index 8ae8384..f64a90e 100644
--- a/DepthConcat.lua
+++ b/DepthConcat.lua
@@ -13,13 +13,13 @@ local DepthConcat, _ = torch.class('nn.DepthConcat', 'nn.Concat')
function DepthConcat:windowNarrow(output, currentOutput, offset)
local outputWindow = output:narrow(self.dimension, offset, currentOutput:size(self.dimension))
- for dim=1,self.size:size(1) do
+ for dim=1,self.outputSize:size(1) do
local currentSize = currentOutput:size(dim)
- if dim ~= self.dimension and self.size[dim] ~= currentSize then
+ if dim ~= self.dimension and self.outputSize[dim] ~= currentSize then
-- 5x5 vs 3x3 -> start = [(5-3)/2] + 1 = 2 (1 pad each side)
-- 9x9 vs 5x5 -> start = [(9-5)/2] + 1 = 3 (2 pad each side)
-- 9x9 vs 4x4 -> start = [(9-4)/2] + 1 = 3.5 (2 pad, 3 pad)
- local start = math.floor(((self.size[dim] - currentSize) / 2) + 1)
+ local start = math.floor(((self.outputSize[dim] - currentSize) / 2) + 1)
outputWindow = outputWindow:narrow(dim, start, currentSize)
end
end
@@ -27,23 +27,25 @@ function DepthConcat:windowNarrow(output, currentOutput, offset)
end
function DepthConcat:updateOutput(input)
+ self.outputSize = self.outputSize or torch.LongStorage()
+
local outs = {}
for i=1,#self.modules do
local currentOutput = self:rethrowErrors(self.modules[i], i, 'updateOutput', input)
outs[i] = currentOutput
if i == 1 then
- self.size:resize(currentOutput:dim()):copy(currentOutput:size())
+ self.outputSize:resize(currentOutput:dim()):copy(currentOutput:size())
else
- self.size[self.dimension] = self.size[self.dimension] + currentOutput:size(self.dimension)
- for dim=1,self.size:size(1) do
+ self.outputSize[self.dimension] = self.outputSize[self.dimension] + currentOutput:size(self.dimension)
+ for dim=1,self.outputSize:size(1) do
if dim ~= self.dimension then
-- take the maximum size (shouldn't change anything for batch dim)
- self.size[dim] = math.max(self.size[dim], currentOutput:size(dim))
+ self.outputSize[dim] = math.max(self.outputSize[dim], currentOutput:size(dim))
end
end
end
end
- self.output:resize(self.size):zero() --zero for padding
+ self.output:resize(self.outputSize):zero() --zero for padding
local offset = 1
for i,module in ipairs(self.modules) do