diff options
author | Ronan Collobert <ronan@collobert.com> | 2012-01-30 19:54:41 +0400 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2012-01-30 19:54:41 +0400 |
commit | 5468752e0a88537c3f27034cc7031e78a67f9202 (patch) | |
tree | 4df10ec8b98c720438d63cfb33cc88dc7aca1c4f | |
parent | 4df3893abd1b9f840f1d9a8c1859799ccbf941de (diff) |
Merge branch 'master' into newpack
Conflicts:
extra/nn/test/test.lua
-rw-r--r-- | Concat.lua | 29 | ||||
-rw-r--r-- | Parallel.lua | 21 | ||||
-rw-r--r-- | generic/SpatialMaxPooling.c | 4 | ||||
-rw-r--r-- | test/test.lua | 22 |
4 files changed, 62 insertions, 14 deletions
@@ -17,9 +17,10 @@ function Concat:get(index) end function Concat:updateOutput(input) + local outs = {} for i=1,#self.modules do local currentOutput = self.modules[i]:updateOutput(input) - + outs[i] = currentOutput if i == 1 then self.size:resize(currentOutput:dim()):copy(currentOutput:size()) else @@ -29,8 +30,9 @@ function Concat:updateOutput(input) self.output:resize(self.size) local offset = 1 - for _,module in ipairs(self.modules) do - local currentOutput = module:updateOutput(input) + for i,module in ipairs(self.modules) do + --local currentOutput = module:updateOutput(input) + local currentOutput = outs[i] self.output:narrow(self.dimension, offset, currentOutput:size(self.dimension)):copy(currentOutput) offset = offset + currentOutput:size(self.dimension) end @@ -117,3 +119,24 @@ function Concat:parameters() end return w,gw end + +function Concat:__tostring__() + local tab = ' ' + local line = '\n' + local next = ' |`-> ' + local ext = ' | ' + local extlast = ' ' + local last = ' ... -> ' + local str = 'nn.Concat' + str = str .. ' {' .. line .. tab .. 'input' + for i=1,#self.modules do + if i == self.modules then + str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast) + else + str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext) + end + end + str = str .. line .. tab .. last .. 'output' + str = str .. line .. '}' + return str +end diff --git a/Parallel.lua b/Parallel.lua index 04a8bdb..3c625bc 100644 --- a/Parallel.lua +++ b/Parallel.lua @@ -135,3 +135,24 @@ function Parallel:parameters() end return w,gw end + +function Parallel:__tostring__() + local tab = ' ' + local line = '\n' + local next = ' |`-> ' + local ext = ' | ' + local extlast = ' ' + local last = ' ... -> ' + local str = 'nn.Parallel' + str = str .. ' {' .. line .. tab .. 'input' + for i=1,#self.modules do + if i == self.modules then + str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast) + else + str = str .. line .. tab .. next .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext) + end + end + str = str .. line .. tab .. last .. 'output' + str = str .. line .. '}' + return str +end diff --git a/generic/SpatialMaxPooling.c b/generic/SpatialMaxPooling.c index b9fab3b..845181b 100644 --- a/generic/SpatialMaxPooling.c +++ b/generic/SpatialMaxPooling.c @@ -75,8 +75,8 @@ static int nn_(SpatialMaxPooling_updateOutput)(lua_State *L) *op = maxval; // store location of max (x,y) - *indyp = (int)(maxindex / dW)+1; - *indxp = (maxindex % dW) +1; + *indyp = (int)(maxindex / kW)+1; + *indxp = (maxindex % kW) +1; } } } diff --git a/test/test.lua b/test/test.lua index c18d3a2..3480829 100644 --- a/test/test.lua +++ b/test/test.lua @@ -812,15 +812,19 @@ function nntest.SpatialSubSampling() end function nntest.SpatialMaxPooling() - local fanin = math.random(1,4) - local osizex = math.random(1,20) - local osizey = math.random(1,20) - local mx = math.random(2,4) - local my = math.random(2,4) - local sizex = osizex*mx - local sizey = osizey*my - local module = nn.SpatialMaxPooling(mx,my,mx,my) - local input = torch.rand(fanin,sizey,sizex) + local from = math.random(1,10) + local to = math.random(1,10) + local ki = math.random(1,10) + local kj = math.random(1,10) + local si = math.random(1,4) + local sj = math.random(1,4) + local outi = math.random(10,20) + local outj = math.random(10,20) + local ini = (outi-1)*si+ki + local inj = (outj-1)*sj+kj + + local module = nn.SpatialMaxPooling(ki,kj,si,sj) + local input = lab.rand(from,ini,inj) local err = jac.testJacobian(module, input) mytester:assertlt(err, precision, 'error on state ') |