Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2011-07-08 07:47:26 +0400
committerClement Farabet <clement.farabet@gmail.com>2011-07-08 07:47:26 +0400
commit4996633b6ed7b775ebfe5ba0704a48a0219a5c40 (patch)
tree36c169175b33c8a4726821281e4a661d94a7babc /SpatialNormalization.lua
parent60a6a1342e7ee26fc885fd50ff224e023ec3c46d (diff)
using torch.Tensor()
Diffstat (limited to 'SpatialNormalization.lua')
-rw-r--r--SpatialNormalization.lua24
1 files changed, 13 insertions, 11 deletions
diff --git a/SpatialNormalization.lua b/SpatialNormalization.lua
index 24ffc5e..5693cb2 100644
--- a/SpatialNormalization.lua
+++ b/SpatialNormalization.lua
@@ -28,29 +28,28 @@ function SpatialNormalization:__init(...) -- kernel for weighted mean | nb of fe
parent.__init(self)
-- get args
- local args, nf, ker, thres, kers1D
+ local args, nf, ker, thres
= xlua.unpack(
{...},
'nn.SpatialNormalization',
help_desc .. '\n' .. help_example,
{arg='nInputPlane', type='number', help='number of input maps', req=true},
- {arg='kernel', type='torch.Tensor', help='a KxK filtering kernel'},
- {arg='threshold', type='number', help='threshold, for division [default = adaptive]'},
- {arg='kernels', type='table', help='two 1D filtering kernels (1xK and Kx1)'}
+ {arg='kernel', type='torch.Tensor | table', help='a KxK filtering kernel or two {1xK, Kx1} 1D kernels'},
+ {arg='threshold', type='number', help='threshold, for division [default = adaptive]'}
)
-- check args
- if not ker and not kers1D then
+ if not ker then
xerror('please provide kernel(s)', 'nn.SpatialNormalization', args.usage)
end
- self.kernel = ker or kers1D
+ self.kernel = ker
local ker2
- if kers1D then
- ker = kers1D[1]
- ker2 = kers1D[2]
+ if type(ker) == 'table' then
+ ker2 = ker[2]
+ ker = ker[1]
end
self.nfeatures = nf
- self.fixedThres = thres -- optional, if not provided, the global std is used
+ self.fixedThres = thres
-- padding values
self.padW = math.floor(ker:size(2)/2)
@@ -211,7 +210,7 @@ function SpatialNormalization:forward(input)
self.thresMod.threshold = self.fixedThres or math.max(meanstd,1e-3)
self.thresMod.val = self.fixedThres or math.max(meanstd,1e-3)
self.stdDev = self.thresMod:forward(self.inStdDev)
-
+
--remove std dev
self.diviseMod:forward{self.inputZeroMean,self.stdDev}
self.output = self.diviseMod.output
@@ -271,6 +270,7 @@ function SpatialNormalization:write(file)
file:writeInt(self.ker2WisPair)
file:writeInt(self.ker2HisPair)
end
+ file:writeInt(self.fixedThres or 0)
end
function SpatialNormalization:read(file)
@@ -297,4 +297,6 @@ function SpatialNormalization:read(file)
self.ker2WisPair = file:readInt()
self.ker2HisPair = file:readInt()
end
+ self.fixedThres = file:readInt()
+ if self.fixedThres == 0 then self.fixedThres = nil end
end