diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2011-07-08 07:47:26 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2011-07-08 07:47:26 +0400 |
commit | 4996633b6ed7b775ebfe5ba0704a48a0219a5c40 (patch) | |
tree | 36c169175b33c8a4726821281e4a661d94a7babc /SpatialNormalization.lua | |
parent | 60a6a1342e7ee26fc885fd50ff224e023ec3c46d (diff) |
using torch.Tensor()
Diffstat (limited to 'SpatialNormalization.lua')
-rw-r--r-- | SpatialNormalization.lua | 24 |
1 files changed, 13 insertions, 11 deletions
diff --git a/SpatialNormalization.lua b/SpatialNormalization.lua index 24ffc5e..5693cb2 100644 --- a/SpatialNormalization.lua +++ b/SpatialNormalization.lua @@ -28,29 +28,28 @@ function SpatialNormalization:__init(...) -- kernel for weighted mean | nb of fe parent.__init(self) -- get args - local args, nf, ker, thres, kers1D + local args, nf, ker, thres = xlua.unpack( {...}, 'nn.SpatialNormalization', help_desc .. '\n' .. help_example, {arg='nInputPlane', type='number', help='number of input maps', req=true}, - {arg='kernel', type='torch.Tensor', help='a KxK filtering kernel'}, - {arg='threshold', type='number', help='threshold, for division [default = adaptive]'}, - {arg='kernels', type='table', help='two 1D filtering kernels (1xK and Kx1)'} + {arg='kernel', type='torch.Tensor | table', help='a KxK filtering kernel or two {1xK, Kx1} 1D kernels'}, + {arg='threshold', type='number', help='threshold, for division [default = adaptive]'} ) -- check args - if not ker and not kers1D then + if not ker then xerror('please provide kernel(s)', 'nn.SpatialNormalization', args.usage) end - self.kernel = ker or kers1D + self.kernel = ker local ker2 - if kers1D then - ker = kers1D[1] - ker2 = kers1D[2] + if type(ker) == 'table' then + ker2 = ker[2] + ker = ker[1] end self.nfeatures = nf - self.fixedThres = thres -- optional, if not provided, the global std is used + self.fixedThres = thres -- padding values self.padW = math.floor(ker:size(2)/2) @@ -211,7 +210,7 @@ function SpatialNormalization:forward(input) self.thresMod.threshold = self.fixedThres or math.max(meanstd,1e-3) self.thresMod.val = self.fixedThres or math.max(meanstd,1e-3) self.stdDev = self.thresMod:forward(self.inStdDev) - + --remove std dev self.diviseMod:forward{self.inputZeroMean,self.stdDev} self.output = self.diviseMod.output @@ -271,6 +270,7 @@ function SpatialNormalization:write(file) file:writeInt(self.ker2WisPair) file:writeInt(self.ker2HisPair) end + file:writeInt(self.fixedThres or 0) end function SpatialNormalization:read(file) @@ -297,4 +297,6 @@ function SpatialNormalization:read(file) self.ker2WisPair = file:readInt() self.ker2HisPair = file:readInt() end + self.fixedThres = file:readInt() + if self.fixedThres == 0 then self.fixedThres = nil end end |