Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-08 05:34:17 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-08 05:34:17 +0400
commit60a6a1342e7ee26fc885fd50ff224e023ec3c46d (patch)
tree043420c4edefd4187521d83f2ee4924b87283b21 /SpatialNormalization.lua
parent0d3f8772d0d5872979af0685f40548ef919f684c (diff)
Added SpatialNormalization module. Still buggy with non-square kernels.
Diffstat (limited to 'SpatialNormalization.lua')
-rw-r--r--SpatialNormalization.lua300
1 file changed, 300 insertions, 0 deletions
diff --git a/SpatialNormalization.lua b/SpatialNormalization.lua
new file mode 100644
index 0000000..24ffc5e
--- /dev/null
+++ b/SpatialNormalization.lua
@@ -0,0 +1,300 @@
+-- SpatialNormalization: local contrast normalization over 2D feature maps.
+-- The help strings below are surfaced by xlua.unpack on bad arguments.
+local SpatialNormalization, parent = torch.class('nn.SpatialNormalization','nn.Module')
+
+local help_desc =
+[[a spatial (2D) contrast normalizer
+* computes the local mean and local std deviation
+ across all input features, using the given 2D kernel
+* the local mean is then removed from all maps, and the std dev
+ used to divide the inputs, with a threshold
+* if no threshold is given, the global std dev is used
+* weight replication is used to preserve sizes (this is
+ better than zero-padding, but more costly to compute, use
+ nn.ContrastNormalization to use zero-padding)
+* two 1D kernels can be used instead of a single 2D kernel. This
+ is beneficial to integrate information over large neighborhoods.
+]]
+
+local help_example =
+[[EX:
+-- create a spatial normalizer, with a 9x9 gaussian kernel
+-- works on 8 input feature maps, therefore the mean+dev will
+-- be estimated on 8x9x9 cubes
+stimulus = lab.randn(8,500,500)
+gaussian = image.gaussian(9)
+mod = nn.SpatialNormalization(gaussian, 8)
+result = mod:forward(stimulus)]]
+
+-- Constructor. Arguments are parsed by xlua.unpack and may be positional
+-- or named (see help_desc/help_example above):
+--   nInputPlane (number, required) : number of input feature maps
+--   kernel      (torch.Tensor)     : a single 2D KxK smoothing kernel
+--   threshold   (number, optional) : fixed threshold for the divisive step;
+--                                    when nil, an adaptive threshold is
+--                                    computed at each forward() call
+--   kernels     (table, optional)  : two separable 1D kernels {1xK, Kx1},
+--                                    given instead of 'kernel'
+function SpatialNormalization:__init(...) -- kernel for weighted mean | nb of features
+ parent.__init(self)
+
+ -- get args
+ local args, nf, ker, thres, kers1D
+ = xlua.unpack(
+ {...},
+ 'nn.SpatialNormalization',
+ help_desc .. '\n' .. help_example,
+ {arg='nInputPlane', type='number', help='number of input maps', req=true},
+ {arg='kernel', type='torch.Tensor', help='a KxK filtering kernel'},
+ {arg='threshold', type='number', help='threshold, for division [default = adaptive]'},
+ {arg='kernels', type='table', help='two 1D filtering kernels (1xK and Kx1)'}
+ )
+
+ -- check args
+ if not ker and not kers1D then
+ xerror('please provide kernel(s)', 'nn.SpatialNormalization', args.usage)
+ end
+ -- self.kernel is either a single tensor or the table of two 1D kernels;
+ -- write()/read() below use its type to decide what extra state to persist
+ self.kernel = ker or kers1D
+ local ker2
+ if kers1D then
+ ker = kers1D[1]
+ ker2 = kers1D[2]
+ end
+ self.nfeatures = nf
+ self.fixedThres = thres -- optional, if not provided, the global std is used
+
+ -- padding values (half-kernel, so the convolutions preserve spatial size)
+ self.padW = math.floor(ker:size(2)/2)
+ self.padH = math.floor(ker:size(1)/2)
+ self.kerWisPair = 0
+ self.kerHisPair = 0
+
+ -- padding values for 2nd kernel
+ if ker2 then
+ self.pad2W = math.floor(ker2:size(2)/2)
+ self.pad2H = math.floor(ker2:size(1)/2)
+ else
+ self.pad2W = 0
+ self.pad2H = 0
+ end
+ self.ker2WisPair = 0
+ self.ker2HisPair = 0
+
+ -- normalize kernel (sum = 1, so the convolution is a weighted average);
+ -- NOTE: this mutates the caller's kernel tensors in place
+ ker:div(ker:sum())
+ if ker2 then ker2:div(ker2:sum()) end
+
+ -- manage the case where ker is even size (for padding issue)
+ -- the *isPair flags remove one column/row of padding on the right/bottom
+ -- side (see the SpatialPadding constructions below) to keep output sizes
+ -- equal to input sizes when the kernel has an even dimension
+ if (ker:size(2)/2 == math.floor(ker:size(2)/2)) then
+ print ('Warning, kernel width is even -> not symetric padding')
+ self.kerWisPair = 1
+ end
+ if (ker:size(1)/2 == math.floor(ker:size(1)/2)) then
+ print ('Warning, kernel height is even -> not symetric padding')
+ self.kerHisPair = 1
+ end
+ if (ker2 and ker2:size(2)/2 == math.floor(ker2:size(2)/2)) then
+ print ('Warning, kernel width is even -> not symetric padding')
+ self.ker2WisPair = 1
+ end
+ if (ker2 and ker2:size(1)/2 == math.floor(ker2:size(1)/2)) then
+ print ('Warning, kernel height is even -> not symetric padding')
+ self.ker2HisPair = 1
+ end
+
+ -- create convolution for computing the mean:
+ -- pad -> per-map convolution (oneToOne table) -> Sum over the feature
+ -- dimension -> Replicate back to nf maps, so every output map holds the
+ -- same cross-feature smoothed signal
+ local convo1 = nn.Sequential()
+ convo1:add(nn.SpatialPadding(self.padW,self.padW-self.kerWisPair,
+ self.padH,self.padH-self.kerHisPair))
+ local ctable = nn.tables.oneToOne(nf)
+ convo1:add(nn.SpatialConvolutionTable(ctable,ker:size(2),ker:size(1)))
+ convo1:add(nn.Sum(1))
+ convo1:add(nn.Replicate(nf))
+ -- set kernel (same fixed weights in every map; bias stays 0 so the
+ -- chain is a pure linear filter)
+ local fb = convo1.modules[2].weight
+ for i=1,fb:size(1) do fb[i]:copy(ker) end
+ -- set bias to 0
+ convo1.modules[2].bias:zero()
+
+ -- 2nd ker ?
+ if ker2 then
+ local convo2 = nn.Sequential()
+ convo2:add(nn.SpatialPadding(self.pad2W,self.pad2W-self.ker2WisPair,
+ self.pad2H,self.pad2H-self.ker2HisPair))
+ local ctable = nn.tables.oneToOne(nf)
+ convo2:add(nn.SpatialConvolutionTable(ctable,ker2:size(2),ker2:size(1)))
+ convo2:add(nn.Sum(1))
+ convo2:add(nn.Replicate(nf))
+ -- set kernel
+ local fb = convo2.modules[2].weight
+ for i=1,fb:size(1) do fb[i]:copy(ker2) end
+ -- set bias to 0
+ convo2.modules[2].bias:zero()
+ -- convo is a double convo now:
+ local convopack = nn.Sequential()
+ convopack:add(convo1)
+ convopack:add(convo2)
+ self.convo = convopack
+ else
+ self.convo = convo1
+ end
+
+ -- create convolution for computing the meanstd
+ -- (same structure and weights as the mean path, applied to the squared
+ -- zero-mean signal in forward(); note 'ctable' here is the connection
+ -- table declared for convo1 above — same oneToOne(nf) wiring)
+ local convostd1 = nn.Sequential()
+ convostd1:add(nn.SpatialPadding(self.padW,self.padW-self.kerWisPair,
+ self.padH,self.padH-self.kerHisPair))
+ convostd1:add(nn.SpatialConvolutionTable(ctable,ker:size(2),ker:size(1)))
+ convostd1:add(nn.Sum(1))
+ convostd1:add(nn.Replicate(nf))
+ -- set kernel
+ local fb = convostd1.modules[2].weight
+ for i=1,fb:size(1) do fb[i]:copy(ker) end
+ -- set bias to 0
+ convostd1.modules[2].bias:zero()
+
+ -- 2nd ker ?
+ if ker2 then
+ local convostd2 = nn.Sequential()
+ convostd2:add(nn.SpatialPadding(self.pad2W,self.pad2W-self.ker2WisPair,
+ self.pad2H,self.pad2H-self.ker2HisPair))
+ convostd2:add(nn.SpatialConvolutionTable(ctable,ker2:size(2),ker2:size(1)))
+ convostd2:add(nn.Sum(1))
+ convostd2:add(nn.Replicate(nf))
+ -- set kernel
+ local fb = convostd2.modules[2].weight
+ for i=1,fb:size(1) do fb[i]:copy(ker2) end
+ -- set bias to 0
+ convostd2.modules[2].bias:zero()
+ -- convo is a double convo now:
+ local convopack = nn.Sequential()
+ convopack:add(convostd1)
+ convopack:add(convostd2)
+ self.convostd = convopack
+ else
+ self.convostd = convostd1
+ end
+
+ -- other operation (elementwise modules wired together in forward())
+ self.squareMod = nn.Square()
+ self.sqrtMod = nn.Sqrt()
+ self.subtractMod = nn.CSubTable()
+ self.meanDiviseMod = nn.CDivTable()
+ self.stdDiviseMod = nn.CDivTable()
+ self.diviseMod = nn.CDivTable()
+ self.thresMod = nn.Threshold()
+ -- some tempo states (caches of forward() intermediates, reused by
+ -- backward(); coef starts 1x1 so the first forward() rebuilds it)
+ self.coef = torch.Tensor(1,1)
+ self.inConvo = torch.Tensor()
+ self.inMean = torch.Tensor()
+ self.inputZeroMean = torch.Tensor()
+ self.inputZeroMeanSq = torch.Tensor()
+ self.inConvoVar = torch.Tensor()
+ self.inVar = torch.Tensor() -- NOTE(review): not referenced by forward/backward in this file
+ self.inStdDev = torch.Tensor()
+ self.thstd = torch.Tensor() -- NOTE(review): not referenced by forward/backward in this file
+end
+
+-- Forward pass: returns the contrast-normalized input,
+-- (input - local mean) / max(local std dev, threshold).
+-- input: 2D (HxW) or 3D (nfeatures x H x W) tensor; a 2D input is
+-- promoted to a single-plane 3D tensor.
+function SpatialNormalization:forward(input)
+ -- auto switch to 3-channel
+ self.input = input
+ if (input:nDimension() == 2) then
+ self.input = torch.Tensor(1,input:size(1),input:size(2)):copy(input)
+ end
+
+ -- recompute coef only if necessary
+ -- coef is the smoothing chain's response to an all-ones input; dividing
+ -- by it compensates for attenuation near the borders of the padded
+ -- convolution. It depends only on the input's spatial size, so it is
+ -- cached and rebuilt only when that size changes.
+ if (self.input:size(3) ~= self.coef:size(2)) or (self.input:size(2) ~= self.coef:size(1)) then
+ local intVals = torch.Tensor(self.nfeatures,self.input:size(2),self.input:size(3)):fill(1)
+ self.coef = self.convo:forward(intVals)
+ -- deep copy: convo:forward returns the chain's internal output state,
+ -- which the forward call just below would overwrite
+ self.coef = torch.Tensor():resizeAs(self.coef):copy(self.coef)
+ end
+
+ -- compute mean (border-renormalized smoothing of the input)
+ self.inConvo = self.convo:forward(self.input)
+ self.inMean = self.meanDiviseMod:forward{self.inConvo,self.coef}
+ self.inputZeroMean = self.subtractMod:forward{self.input,self.inMean}
+
+ -- compute std dev: sqrt of the smoothed squared zero-mean signal,
+ -- again divided by coef for the borders
+ self.inputZeroMeanSq = self.squareMod:forward(self.inputZeroMean)
+ self.inConvoVar = self.convostd:forward(self.inputZeroMeanSq)
+ self.inStdDevNotUnit = self.sqrtMod:forward(self.inConvoVar)
+ self.inStdDev = self.stdDiviseMod:forward({self.inStdDevNotUnit,self.coef})
+ -- floor the std-dev map: fixed threshold if provided at construction,
+ -- otherwise adaptive (mean std dev over the map, never below 1e-3)
+ local meanstd = self.inStdDev:mean()
+ self.thresMod.threshold = self.fixedThres or math.max(meanstd,1e-3)
+ self.thresMod.val = self.fixedThres or math.max(meanstd,1e-3)
+ self.stdDev = self.thresMod:forward(self.inStdDev)
+
+ --remove std dev (divisive normalization)
+ self.diviseMod:forward{self.inputZeroMean,self.stdDev}
+ self.output = self.diviseMod.output
+ return self.output
+end
+
+-- Backward pass: computes gradInput from gradOutput.
+-- Relies on the intermediate states cached by forward(), so forward()
+-- must have been called on the same input first.
+function SpatialNormalization:backward(input, gradOutput)
+ -- auto switch to 3-channel
+ self.input = input
+ if (input:nDimension() == 2) then
+ self.input = torch.Tensor(1,input:size(1),input:size(2)):copy(input)
+ end
+ self.gradInput:resizeAs(self.input):zero()
+
+ -- backprop all, in reverse order of forward(): final division ->
+ -- threshold -> coef division -> sqrt -> std-dev smoothing -> square ->
+ -- mean subtraction -> mean smoothing. Gradients w.r.t. the constant
+ -- coef tensor are discarded (only element [1] of those table-module
+ -- gradients is propagated).
+ local gradDiv = self.diviseMod:backward({self.inputZeroMean,self.stdDev},gradOutput)
+ local gradThres = gradDiv[2]
+ local gradZeroMean = gradDiv[1]
+ local gradinStdDev = self.thresMod:backward(self.inStdDev,gradThres)
+ local gradstdDiv = self.stdDiviseMod:backward({self.inStdDevNotUnit,self.coef},gradinStdDev)
+ local gradinStdDevNotUnit = gradstdDiv[1]
+ local gradinConvoVar = self.sqrtMod:backward(self.inConvoVar,gradinStdDevNotUnit)
+ local gradinputZeroMeanSq = self.convostd:backward(self.inputZeroMeanSq,gradinConvoVar)
+ -- the zero-mean signal feeds both the numerator of the final division
+ -- and the std-dev path, so the two gradient contributions are summed
+ gradZeroMean:add(self.squareMod:backward(self.inputZeroMean,gradinputZeroMeanSq))
+ local gradDiff = self.subtractMod:backward({self.input,self.inMean},gradZeroMean)
+ local gradinMean = gradDiff[2]
+ local gradinConvoNotUnit = self.meanDiviseMod:backward({self.inConvo,self.coef},gradinMean)
+ local gradinConvo = gradinConvoNotUnit[1]
+ -- first part of the gradInput: direct path through the subtraction
+ self.gradInput:add(gradDiff[1])
+ -- second part of the gradInput: path through the mean-smoothing convolution
+ self.gradInput:add(self.convo:backward(self.input,gradinConvo))
+ return self.gradInput
+end
+
+-- Serialization: writes the module's state to 'file'.
+-- The field order here must stay in sync with read() below. The padding
+-- fields for the second 1D kernel are only written when the module was
+-- built from two 1D kernels (self.kernel is then a table, not a tensor).
+function SpatialNormalization:write(file)
+ parent.write(self,file)
+ file:writeObject(self.kernel)
+ file:writeInt(self.nfeatures)
+ file:writeInt(self.padW)
+ file:writeInt(self.padH)
+ file:writeInt(self.kerWisPair)
+ file:writeInt(self.kerHisPair)
+ file:writeObject(self.convo)
+ file:writeObject(self.convostd)
+ file:writeObject(self.squareMod)
+ file:writeObject(self.sqrtMod)
+ file:writeObject(self.subtractMod)
+ file:writeObject(self.meanDiviseMod)
+ file:writeObject(self.stdDiviseMod)
+ file:writeObject(self.thresMod)
+ file:writeObject(self.diviseMod)
+ file:writeObject(self.coef)
+ if type(self.kernel) == 'table' then
+ file:writeInt(self.pad2W)
+ file:writeInt(self.pad2H)
+ file:writeInt(self.ker2WisPair)
+ file:writeInt(self.ker2HisPair)
+ end
+end
+
+-- Deserialization: restores the module's state from 'file'.
+-- Exact mirror of write() above — field order must match. Note that the
+-- runtime caches (inConvo, inMean, etc.) are not persisted; they are
+-- rebuilt by the next forward() call.
+function SpatialNormalization:read(file)
+ parent.read(self,file)
+ self.kernel = file:readObject()
+ self.nfeatures = file:readInt()
+ self.padW = file:readInt()
+ self.padH = file:readInt()
+ self.kerWisPair = file:readInt()
+ self.kerHisPair = file:readInt()
+ self.convo = file:readObject()
+ self.convostd = file:readObject()
+ self.squareMod = file:readObject()
+ self.sqrtMod = file:readObject()
+ self.subtractMod = file:readObject()
+ self.meanDiviseMod = file:readObject()
+ self.stdDiviseMod = file:readObject()
+ self.thresMod = file:readObject()
+ self.diviseMod = file:readObject()
+ self.coef = file:readObject()
+ -- second-kernel fields are only present when two 1D kernels were used
+ if type(self.kernel) == 'table' then
+ self.pad2W = file:readInt()
+ self.pad2H = file:readInt()
+ self.ker2WisPair = file:readInt()
+ self.ker2HisPair = file:readInt()
+ end
+end