Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRonan Collobert <ronan@collobert.com>2012-01-25 17:55:20 +0400
committerRonan Collobert <ronan@collobert.com>2012-01-25 17:55:20 +0400
commit4df3893abd1b9f840f1d9a8c1859799ccbf941de (patch)
treee8a1e1cc1b6ea6e47855347b157eaf419fdb357b /SpatialSubtractiveNormalization.lua
initial revamp of torch7 tree
Diffstat (limited to 'SpatialSubtractiveNormalization.lua')
-rw-r--r--SpatialSubtractiveNormalization.lua104
1 files changed, 104 insertions, 0 deletions
diff --git a/SpatialSubtractiveNormalization.lua b/SpatialSubtractiveNormalization.lua
new file mode 100644
index 0000000..4df0fc1
--- /dev/null
+++ b/SpatialSubtractiveNormalization.lua
@@ -0,0 +1,104 @@
+local SpatialSubtractiveNormalization, parent = torch.class('nn.SpatialSubtractiveNormalization','nn.Module')
+
+--- Spatial subtractive normalization: subtracts from each spatial location the
+-- weighted mean of its neighborhood, pooled across all input planes, using the
+-- given averaging kernel.
+-- @param nInputPlane number of input planes (default 1)
+-- @param kernel 1D or 2D averaging kernel; every size must be ODD (default 9x9 of ones)
+function SpatialSubtractiveNormalization:__init(nInputPlane, kernel)
+ parent.__init(self)
+
+ -- get args
+ self.nInputPlane = nInputPlane or 1
+ self.kernel = kernel or torch.Tensor(9,9):fill(1)
+ local kdim = self.kernel:nDimension()
+
+ -- check args: a 1D kernel is applied separably (one horizontal, one vertical pass)
+ if kdim ~= 2 and kdim ~= 1 then
+ error('<SpatialSubtractiveNormalization> averaging kernel must be 2D or 1D')
+ end
+ -- odd sizes give the kernel a well-defined center, so padded convolution
+ -- preserves the input's spatial size
+ if (self.kernel:size(1) % 2) == 0 or (kdim == 2 and (self.kernel:size(2) % 2) == 0) then
+ error('<SpatialSubtractiveNormalization> averaging kernel must have ODD dimensions')
+ end
+
+ -- normalize kernel by its mass and the plane count; any residual scale is
+ -- cancelled in updateOutput, which divides by the estimator's response to ones
+ -- NOTE(review): 'sumall' is old Torch API; newer tensors spell this :sum() — verify
+ self.kernel:div(self.kernel:sumall() * self.nInputPlane)
+
+ -- padding values: half the kernel extent in each direction
+ local padH = math.floor(self.kernel:size(1)/2)
+ local padW = padH
+ if kdim == 2 then
+ padW = math.floor(self.kernel:size(2)/2)
+ end
+
+ -- create convolutional mean extractor:
+ -- zero-pad -> per-plane convolution(s) -> sum over planes -> replicate back
+ self.meanestimator = nn.Sequential()
+ self.meanestimator:add(nn.SpatialZeroPadding(padW, padW, padH, padH))
+ if kdim == 2 then
+ self.meanestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(self.nInputPlane),
+ self.kernel:size(2), self.kernel:size(1)))
+ else
+ -- separable 1D case: horizontal pass (k x 1), then vertical pass (1 x k)
+ self.meanestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(self.nInputPlane),
+ self.kernel:size(1), 1))
+ self.meanestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(self.nInputPlane),
+ 1, self.kernel:size(1)))
+ end
+ self.meanestimator:add(nn.Sum(1))
+ self.meanestimator:add(nn.Replicate(self.nInputPlane))
+
+ -- set kernel and bias: modules[1] is the padding layer, so the convolution
+ -- layers sit at indices 2 (and 3 in the separable case)
+ if kdim == 2 then
+ for i = 1,self.nInputPlane do
+ self.meanestimator.modules[2].weight[i] = self.kernel
+ end
+ self.meanestimator.modules[2].bias:zero()
+ else
+ for i = 1,self.nInputPlane do
+ self.meanestimator.modules[2].weight[i]:copy(self.kernel)
+ self.meanestimator.modules[3].weight[i]:copy(self.kernel)
+ end
+ self.meanestimator.modules[2].bias:zero()
+ self.meanestimator.modules[3].bias:zero()
+ end
+
+ -- other operation
+ self.subtractor = nn.CSubTable()
+ self.divider = nn.CDivTable()
+
+ -- coefficient array, to adjust side effects; the 1x1x1 placeholder size
+ -- forces a recomputation on the first updateOutput call
+ self.coef = torch.Tensor(1,1,1)
+end
+
+--- Forward pass: output = input - local weighted mean.
+-- @param input 3D tensor of size (nInputPlane x height x width)
+-- @return self.output, same size as input
+function SpatialSubtractiveNormalization:updateOutput(input)
+   -- Recompute the border-adjustment coefficients only when the input's
+   -- spatial size changes. self.coef has size (nInputPlane x height x width)
+   -- (Sum(1) then Replicate in the estimator), so input height/width must be
+   -- checked against coef sizes 2 and 3 respectively.
+   -- BUG FIX: the original compared input:size(3) (width) to coef:size(2)
+   -- (height) and input:size(2) (height) to coef:size(1) (nInputPlane),
+   -- forcing a recomputation on every call for non-square inputs and
+   -- potentially skipping it when unrelated sizes coincide.
+   if (input:size(3) ~= self.coef:size(3)) or (input:size(2) ~= self.coef:size(2)) then
+      -- running the estimator over an all-ones input measures how much of the
+      -- kernel overlaps the image at each location (smaller near the borders);
+      -- clone, since the estimator reuses its output buffer on the next call
+      local ones = input.new():resizeAs(input):fill(1)
+      self.coef = self.meanestimator:updateOutput(ones):clone()
+   end
+
+   -- compute mean: raw local sums, rescaled by the border coefficients,
+   -- then subtracted from the input
+   self.localsums = self.meanestimator:updateOutput(input)
+   self.adjustedsums = self.divider:updateOutput{self.localsums, self.coef}
+   self.output = self.subtractor:updateOutput{input, self.adjustedsums}
+
+   -- done
+   return self.output
+end
+
+--- Backward pass: propagates gradOutput back through the subtractor, the
+-- divider and the mean estimator, accumulating both gradient paths that
+-- reach the input.
+-- @param input the tensor given to the last updateOutput call
+-- @param gradOutput gradient w.r.t. the module output, same size as input
+-- @return self.gradInput
+function SpatialSubtractiveNormalization:updateGradInput(input, gradOutput)
+   -- fresh accumulator for the two gradient contributions
+   self.gradInput:resizeAs(input):zero()
+
+   -- gradients w.r.t. the two operands of (input - adjustedsums)
+   local subGrads = self.subtractor:updateGradInput({input, self.adjustedsums}, gradOutput)
+   -- gradients w.r.t. the operands of (localsums / coef); only the numerator
+   -- path is propagated further, coef is treated as a constant
+   local divGrads = self.divider:updateGradInput({self.localsums, self.coef}, subGrads[2])
+
+   -- the input receives gradient through the mean estimator and directly
+   self.gradInput:add(self.meanestimator:updateGradInput(input, divGrads[1]))
+   self.gradInput:add(subGrads[1])
+
+   -- done
+   return self.gradInput
+end
+
+--- Converts the module, including its internal sub-modules, to the given
+-- tensor type (e.g. 'torch.FloatTensor').
+-- @param type the target tensor type name
+-- @return self, for chaining
+function SpatialSubtractiveNormalization:type(type)
+   parent.type(self, type)
+   -- convert the sub-modules in the same order as before
+   for _, child in ipairs{self.meanestimator, self.divider, self.subtractor} do
+      child:type(type)
+   end
+   return self
+end