Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSergey Zagoruyko <zagoruyko2@gmail.com>2015-03-04 17:39:05 +0300
committerSergey Zagoruyko <zagoruyko2@gmail.com>2015-03-04 17:39:05 +0300
commitcb303329bd82f1bb80d3c7eebc14d41195acec83 (patch)
treecf3fc5fcafe6392e5c11e657c3809e9e070de855
parent7470cc81d458066d751156df52d05d3482411e6c (diff)
SpatialAveragePooling divides by kW*kH
-rw-r--r--SpatialAveragePooling.lua16
-rw-r--r--SpatialLPPooling.lua1
-rw-r--r--generic/SpatialAveragePooling.c4
-rw-r--r--test.lua4
4 files changed, 19 insertions, 6 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua
index 13b6b45..90b79aa 100644
--- a/SpatialAveragePooling.lua
+++ b/SpatialAveragePooling.lua
@@ -7,14 +7,26 @@ function SpatialAveragePooling:__init(kW, kH, dW, dH)
self.kH = kH
self.dW = dW or 1
self.dH = dH or 1
+ self.divide = true
end
function SpatialAveragePooling:updateOutput(input)
- return input.nn.SpatialAveragePooling_updateOutput(self, input)
+ input.nn.SpatialAveragePooling_updateOutput(self, input)
+ -- for backward compatibility with saved models
+ -- which are not supposed to have "divide" field
+ if not self.divide then
+ self.output:mul(self.kW*self.kH)
+ end
+ return self.output
end
function SpatialAveragePooling:updateGradInput(input, gradOutput)
if self.gradInput then
- return input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput)
+ input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput)
+ -- for backward compatibility
+ if not self.divide then
+ self.gradInput:mul(self.kW*self.kH)
+ end
+ return self.gradInput
end
end
diff --git a/SpatialLPPooling.lua b/SpatialLPPooling.lua
index 4e6fe74..fc56296 100644
--- a/SpatialLPPooling.lua
+++ b/SpatialLPPooling.lua
@@ -17,6 +17,7 @@ function SpatialLPPooling:__init(nInputPlane, pnorm, kW, kH, dW, dH)
self:add(nn.Power(pnorm))
end
self:add(nn.SpatialAveragePooling(kW, kH, dW, dH))
+ self:add(nn.MulConstant(kW*kH))
if pnorm == 2 then
self:add(nn.Sqrt())
else
diff --git a/generic/SpatialAveragePooling.c b/generic/SpatialAveragePooling.c
index 2052d05..681938e 100644
--- a/generic/SpatialAveragePooling.c
+++ b/generic/SpatialAveragePooling.c
@@ -83,7 +83,7 @@ static int nn_(SpatialAveragePooling_updateOutput)(lua_State *L)
ptr_input += inputWidth; /* next input line */
}
/* Update output */
- *ptr_output++ += sum;
+ *ptr_output++ += sum/(kW*kH);
}
}
}
@@ -163,7 +163,7 @@ static int nn_(SpatialAveragePooling_updateGradInput)(lua_State *L)
for(ky = 0; ky < kH; ky++)
{
for(kx = 0; kx < kW; kx++)
- ptr_gradInput[kx] += z;
+ ptr_gradInput[kx] += z/(kW*kH);
ptr_gradInput += inputWidth;
}
}
diff --git a/test.lua b/test.lua
index f1576fc..8f5fbaa 100644
--- a/test.lua
+++ b/test.lua
@@ -1706,7 +1706,7 @@ function nntest.SpatialAveragePooling()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
- sap.weight:fill(1.0)
+ sap.weight:fill(1.0/(ki*kj))
sap.bias:fill(0.0)
local output = module:forward(input)
@@ -1737,7 +1737,7 @@ function nntest.SpatialAveragePooling()
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ')
local sap = nn.SpatialSubSampling(from, ki, kj, si, sj)
- sap.weight:fill(1.0)
+ sap.weight:fill(1.0/(ki*kj))
sap.bias:fill(0.0)
local output = module:forward(input)