From b1bd980dc1ea466a640e4af87535fed1044c978c Mon Sep 17 00:00:00 2001 From: Clement Farabet Date: Mon, 22 Oct 2012 14:53:37 -0400 Subject: Using better optimized SpatialConvolutionMap, for Spatial*Normalization layer --- SpatialSubtractiveNormalization.lua | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) (limited to 'SpatialSubtractiveNormalization.lua') diff --git a/SpatialSubtractiveNormalization.lua b/SpatialSubtractiveNormalization.lua index dfa8fd2..f2c2c31 100644 --- a/SpatialSubtractiveNormalization.lua +++ b/SpatialSubtractiveNormalization.lua @@ -30,27 +30,23 @@ function SpatialSubtractiveNormalization:__init(nInputPlane, kernel) self.meanestimator = nn.Sequential() self.meanestimator:add(nn.SpatialZeroPadding(padW, padW, padH, padH)) if kdim == 2 then - self.meanestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(self.nInputPlane), - self.kernel:size(2), self.kernel:size(1))) + self.meanestimator:add(nn.SpatialConvolution(self.nInputPlane, 1, self.kernel:size(2), self.kernel:size(1))) else - self.meanestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(self.nInputPlane), - self.kernel:size(1), 1)) - self.meanestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(self.nInputPlane), - 1, self.kernel:size(1))) + self.meanestimator:add(nn.SpatialConvolutionMap(nn.tables.oneToOne(self.nInputPlane), self.kernel:size(1), 1)) + self.meanestimator:add(nn.SpatialConvolution(self.nInputPlane, 1, 1, self.kernel:size(1))) end - self.meanestimator:add(nn.Sum(1)) self.meanestimator:add(nn.Replicate(self.nInputPlane)) -- set kernel and bias if kdim == 2 then for i = 1,self.nInputPlane do - self.meanestimator.modules[2].weight[i] = self.kernel + self.meanestimator.modules[2].weight[1][i] = self.kernel end self.meanestimator.modules[2].bias:zero() else for i = 1,self.nInputPlane do self.meanestimator.modules[2].weight[i]:copy(self.kernel) - self.meanestimator.modules[3].weight[i]:copy(self.kernel) + self.meanestimator.modules[3].weight[1][i]:copy(self.kernel) end self.meanestimator.modules[2].bias:zero() self.meanestimator.modules[3].bias:zero() -- cgit v1.2.3