diff options
author | Ronan Collobert <ronan@collobert.com> | 2012-01-25 17:55:20 +0400 |
---|---|---|
committer | Ronan Collobert <ronan@collobert.com> | 2012-01-25 17:55:20 +0400 |
commit | 4df3893abd1b9f840f1d9a8c1859799ccbf941de (patch) | |
tree | e8a1e1cc1b6ea6e47855347b157eaf419fdb357b /SpatialLPPooling.lua |
initial revamp of torch7 tree
Diffstat (limited to 'SpatialLPPooling.lua')
-rw-r--r-- | SpatialLPPooling.lua | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/SpatialLPPooling.lua b/SpatialLPPooling.lua new file mode 100644 index 0000000..9b9c87d --- /dev/null +++ b/SpatialLPPooling.lua @@ -0,0 +1,32 @@ +local SpatialLPPooling, parent = torch.class('nn.SpatialLPPooling', 'nn.Sequential') + +function SpatialLPPooling:__init(nInputPlane, pnorm, kW, kH, dW, dH) + parent.__init(self) + + dW = dW or kW + dH = dH or kH + + self.kW = kW + self.kH = kH + self.dW = dW + self.dH = dH + + self.nInputPlane = nInputPlane + self.learnKernel = learnKernel + + if pnorm == 2 then + self:add(nn.Square()) + else + self:add(nn.Power(pnorm)) + end + self:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(nInputPlane), kW, kH, dW, dH)) + if pnorm == 2 then + self:add(nn.Sqrt()) + else + self:add(nn.Power(1/pnorm)) + end + + self:get(2).bias:zero() + self:get(2).weight:fill(1/(kW*kH)) + self:get(2).accGradParameters = nil +end |