diff options
-rw-r--r-- | SpatialAveragePooling.lua | 16 | ||||
-rw-r--r-- | SpatialLPPooling.lua | 1 | ||||
-rw-r--r-- | generic/SpatialAveragePooling.c | 4 | ||||
-rw-r--r-- | test.lua | 4 |
4 files changed, 19 insertions, 6 deletions
diff --git a/SpatialAveragePooling.lua b/SpatialAveragePooling.lua index 13b6b45..90b79aa 100644 --- a/SpatialAveragePooling.lua +++ b/SpatialAveragePooling.lua @@ -7,14 +7,26 @@ function SpatialAveragePooling:__init(kW, kH, dW, dH) self.kH = kH self.dW = dW or 1 self.dH = dH or 1 + self.divide = true end function SpatialAveragePooling:updateOutput(input) - return input.nn.SpatialAveragePooling_updateOutput(self, input) + input.nn.SpatialAveragePooling_updateOutput(self, input) + -- for backward compatibility with saved models + -- which are not supposed to have "divide" field + if not self.divide then + self.output:mul(self.kW*self.kH) + end + return self.output end function SpatialAveragePooling:updateGradInput(input, gradOutput) if self.gradInput then - return input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput) + input.nn.SpatialAveragePooling_updateGradInput(self, input, gradOutput) + -- for backward compatibility + if not self.divide then + self.gradInput:mul(self.kW*self.kH) + end + return self.gradInput end end diff --git a/SpatialLPPooling.lua b/SpatialLPPooling.lua index 4e6fe74..fc56296 100644 --- a/SpatialLPPooling.lua +++ b/SpatialLPPooling.lua @@ -17,6 +17,7 @@ function SpatialLPPooling:__init(nInputPlane, pnorm, kW, kH, dW, dH) self:add(nn.Power(pnorm)) end self:add(nn.SpatialAveragePooling(kW, kH, dW, dH)) + self:add(nn.MulConstant(kW*kH)) if pnorm == 2 then self:add(nn.Sqrt()) else diff --git a/generic/SpatialAveragePooling.c b/generic/SpatialAveragePooling.c index 2052d05..681938e 100644 --- a/generic/SpatialAveragePooling.c +++ b/generic/SpatialAveragePooling.c @@ -83,7 +83,7 @@ static int nn_(SpatialAveragePooling_updateOutput)(lua_State *L) ptr_input += inputWidth; /* next input line */ } /* Update output */ - *ptr_output++ += sum; + *ptr_output++ += sum/(kW*kH); } } } @@ -163,7 +163,7 @@ static int nn_(SpatialAveragePooling_updateGradInput)(lua_State *L) for(ky = 0; ky < kH; ky++) { for(kx = 0; kx < kW; kx++) - ptr_gradInput[kx] += z; + ptr_gradInput[kx] += z/(kW*kH); ptr_gradInput += inputWidth; } } @@ -1706,7 +1706,7 @@ function nntest.SpatialAveragePooling() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ') local sap = nn.SpatialSubSampling(from, ki, kj, si, sj) - sap.weight:fill(1.0) + sap.weight:fill(1.0/(ki*kj)) sap.bias:fill(0.0) local output = module:forward(input) @@ -1737,7 +1737,7 @@ function nntest.SpatialAveragePooling() mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err (Batch) ') local sap = nn.SpatialSubSampling(from, ki, kj, si, sj) - sap.weight:fill(1.0) + sap.weight:fill(1.0/(ki*kj)) sap.bias:fill(0.0) local output = module:forward(input) |