diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-04-05 00:40:47 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-04-05 00:40:47 +0400 |
commit | 14784d8b68a68a352fd70efa7dd77c039049066a (patch) | |
tree | 8e991c5d0dd61e0f16eafdbb4afb8f262b8c17c7 /SpatialContrastiveNormalization.lua | |
parent | 4e94a75b4193e004f50d7e8fe2808fb8ebb2375f (diff) |
Fixed normalization for 1D kernels, and added Contrastive Normalization
Diffstat (limited to 'SpatialContrastiveNormalization.lua')
-rw-r--r-- | SpatialContrastiveNormalization.lua | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/SpatialContrastiveNormalization.lua b/SpatialContrastiveNormalization.lua new file mode 100644 index 0000000..262d3b1 --- /dev/null +++ b/SpatialContrastiveNormalization.lua @@ -0,0 +1,42 @@ +local SpatialContrastiveNormalization, parent = torch.class('nn.SpatialContrastiveNormalization','nn.Module') + +function SpatialContrastiveNormalization:__init(nInputPlane, kernel, threshold, thresval) + parent.__init(self) + + -- get args + self.nInputPlane = nInputPlane or 1 + self.kernel = kernel or torch.Tensor(9,9):fill(1) + self.threshold = threshold or 1e-4 + self.thresval = thresval or 1e-4 + local kdim = self.kernel:nDimension() + + -- check args + if kdim ~= 2 and kdim ~= 1 then + error('<SpatialContrastiveNormalization> averaging kernel must be 2D or 1D') + end + if (self.kernel:size(1) % 2) == 0 or (kdim == 2 and (self.kernel:size(2) % 2) == 0) then + error('<SpatialContrastiveNormalization> averaging kernel must have ODD dimensions') + end + + -- instantiate sub+div normalization + self.normalizer = nn.Sequential() + self.normalizer:add(nn.SpatialSubtractiveNormalization(self.nInputPlane, self.kernel)) + self.normalizer:add(nn.SpatialDivisiveNormalization(self.nInputPlane, self.kernel, + self.threshold, self.threshval)) +end + +function SpatialContrastiveNormalization:updateOutput(input) + self.output = self.normalizer:forward(input) + return self.output +end + +function SpatialContrastiveNormalization:updateGradInput(input, gradOutput) + self.gradInput = self.normalizer:backward(input, gradOutput) + return self.gradInput +end + +function SpatialContrastiveNormalization:type(type) + parent.type(self,type) + self.normalizer:type(type) + return self +end |