diff options
author | Clement Farabet <clement.farabet@gmail.com> | 2012-09-05 23:21:08 +0400 |
---|---|---|
committer | Clement Farabet <clement.farabet@gmail.com> | 2012-09-05 23:21:08 +0400 |
commit | 7cd682f8dca26a2a39280b596bc8ab0206ee8f1d (patch) | |
tree | c0997bd3e0adf65494c60cb26cd8cd5fabdae1d9 | |
parent | 2fefe2aa3217aaf2118c81fbb648cea243bc4bed (diff) | |
parent | 3c36e6d5cc010a9e5d41aec92b8620a6dd3073b3 (diff) |
Merge branch 'master' of https://github.com/2ndforks/torch into cudaconv
-rw-r--r-- | SpatialConvolutionMap.lua | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/SpatialConvolutionMap.lua b/SpatialConvolutionMap.lua index 4b525ba..11718fd 100644 --- a/SpatialConvolutionMap.lua +++ b/SpatialConvolutionMap.lua @@ -54,6 +54,37 @@ function nn.tables.random(nin, nout, nto) return tbl end +function constructTableRev(conMatrix) + local conMatrixL = conMatrix:type('torch.LongTensor') + -- Construct reverse lookup connection table + local thickness = conMatrixL:select(2,2):max() + -- approximate fanin check + if (#conMatrixL)[1] % thickness == 0 then + -- do a proper fanin check and set revTable + local fanin = (#conMatrixL)[1] / thickness + local revTable = torch.Tensor(thickness, fanin, 2) + for ii=1,thickness do + local tempf = fanin + for jj=1,(#conMatrixL)[1] do + if conMatrixL[jj][2] == ii then + if tempf <= 0 then break end + revTable[ii][tempf][1] = conMatrixL[jj][1] + revTable[ii][tempf][2] = jj + tempf = tempf - 1 + end + end + if tempf ~= 0 then + fanin = -1 + break + end + end + if fanin ~= -1 then + return revTable + end + end + return {} +end + function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH) parent.__init(self) @@ -65,9 +96,9 @@ function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH) self.dW = dW self.dH = dH self.connTable = conMatrix + self.connTableRev = constructTableRev(conMatrix) self.nInputPlane = self.connTable:select(2,1):max() self.nOutputPlane = self.connTable:select(2,2):max() - self.weight = torch.Tensor(self.connTable:size(1), kH, kW) self.bias = torch.Tensor(self.nOutputPlane) self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW) |