Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2012-01-30 19:54:41 +0400
committerRonan Collobert <ronan@collobert.com>2012-01-30 19:54:41 +0400
commit5468752e0a88537c3f27034cc7031e78a67f9202 (patch)
tree4df10ec8b98c720438d63cfb33cc88dc7aca1c4f
parent4df3893abd1b9f840f1d9a8c1859799ccbf941de (diff)
Merge branch 'master' into newpack
Conflicts: extra/nn/test/test.lua
-rw-r--r--Concat.lua29
-rw-r--r--Parallel.lua21
-rw-r--r--generic/SpatialMaxPooling.c4
-rw-r--r--test/test.lua22
4 files changed, 62 insertions, 14 deletions
diff --git a/Concat.lua b/Concat.lua
index 616c394..6543bcc 100644
--- a/Concat.lua
+++ b/Concat.lua
@@ -17,9 +17,10 @@ function Concat:get(index)
end
function Concat:updateOutput(input)
+ local outs = {}
for i=1,#self.modules do
local currentOutput = self.modules[i]:updateOutput(input)
-
+ outs[i] = currentOutput
if i == 1 then
self.size:resize(currentOutput:dim()):copy(currentOutput:size())
else
@@ -29,8 +30,9 @@ function Concat:updateOutput(input)
self.output:resize(self.size)
local offset = 1
- for _,module in ipairs(self.modules) do
- local currentOutput = module:updateOutput(input)
+ for i,module in ipairs(self.modules) do
+ --local currentOutput = module:updateOutput(input)
+ local currentOutput = outs[i]
self.output:narrow(self.dimension, offset, currentOutput:size(self.dimension)):copy(currentOutput)
offset = offset + currentOutput:size(self.dimension)
end
@@ -117,3 +119,24 @@ function Concat:parameters()
end
return w,gw
end
+
+function Concat:__tostring__()
+ local tab = ' '
+ local line = '\n'
+ local next = ' |`-> '
+ local ext = ' | '
+ local extlast = ' '
+ local last = ' ... -> '
+ local str = 'nn.Concat'
+ str = str .. ' {' .. line .. tab .. 'input'
+ for i=1,#self.modules do
+ if i == self.modules then
+ str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast)
+ else
+ str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext)
+ end
+ end
+ str = str .. line .. tab .. last .. 'output'
+ str = str .. line .. '}'
+ return str
+end
diff --git a/Parallel.lua b/Parallel.lua
index 04a8bdb..3c625bc 100644
--- a/Parallel.lua
+++ b/Parallel.lua
@@ -135,3 +135,24 @@ function Parallel:parameters()
end
return w,gw
end
+
+function Parallel:__tostring__()
+ local tab = ' '
+ local line = '\n'
+ local next = ' |`-> '
+ local ext = ' | '
+ local extlast = ' '
+ local last = ' ... -> '
+ local str = 'nn.Parallel'
+ str = str .. ' {' .. line .. tab .. 'input'
+ for i=1,#self.modules do
+ if i == self.modules then
+ str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast)
+ else
+ str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext)
+ end
+ end
+ str = str .. line .. tab .. last .. 'output'
+ str = str .. line .. '}'
+ return str
+end
diff --git a/generic/SpatialMaxPooling.c b/generic/SpatialMaxPooling.c
index b9fab3b..845181b 100644
--- a/generic/SpatialMaxPooling.c
+++ b/generic/SpatialMaxPooling.c
@@ -75,8 +75,8 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L)
*op = maxval;
// store location of max (x,y)
- *indyp = (int)(maxindex / dW)+1;
- *indxp = (maxindex % dW) +1;
+ *indyp = (int)(maxindex / kW)+1;
+ *indxp = (maxindex % kW) +1;
}
}
}
diff --git a/test/test.lua b/test/test.lua
index c18d3a2..3480829 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -812,15 +812,19 @@ function nntest.SpatialSubSampling()
end
function nntest.SpatialMaxPooling()
- local fanin = math.random(1,4)
- local osizex = math.random(1,20)
- local osizey = math.random(1,20)
- local mx = math.random(2,4)
- local my = math.random(2,4)
- local sizex = osizex*mx
- local sizey = osizey*my
- local module = nn.SpatialMaxPooling(mx,my,mx,my)
- local input = torch.rand(fanin,sizey,sizex)
+ local from = math.random(1,10)
+ local to = math.random(1,10)
+ local ki = math.random(1,10)
+ local kj = math.random(1,10)
+ local si = math.random(1,4)
+ local sj = math.random(1,4)
+ local outi = math.random(10,20)
+ local outj = math.random(10,20)
+ local ini = (outi-1)*si+ki
+ local inj = (outj-1)*sj+kj
+
+ local module = nn.SpatialMaxPooling(ki,kj,si,sj)
+ local input = lab.rand(from,ini,inj)
local err = jac.testJacobian(module, input)
mytester:assertlt(err, precision, 'error on state ')