Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClement Farabet <clement.farabet@gmail.com>2012-09-05 23:21:08 +0400
committerClement Farabet <clement.farabet@gmail.com>2012-09-05 23:21:08 +0400
commit7cd682f8dca26a2a39280b596bc8ab0206ee8f1d (patch)
treec0997bd3e0adf65494c60cb26cd8cd5fabdae1d9
parent2fefe2aa3217aaf2118c81fbb648cea243bc4bed (diff)
parent3c36e6d5cc010a9e5d41aec92b8620a6dd3073b3 (diff)
Merge branch 'master' of https://github.com/2ndforks/torch into cudaconv
-rw-r--r--SpatialConvolutionMap.lua33
1 files changed, 32 insertions, 1 deletions
diff --git a/SpatialConvolutionMap.lua b/SpatialConvolutionMap.lua
index 4b525ba..11718fd 100644
--- a/SpatialConvolutionMap.lua
+++ b/SpatialConvolutionMap.lua
@@ -54,6 +54,37 @@ function nn.tables.random(nin, nout, nto)
return tbl
end
+function constructTableRev(conMatrix)
+ local conMatrixL = conMatrix:type('torch.LongTensor')
+ -- Construct reverse lookup connection table
+ local thickness = conMatrixL:select(2,2):max()
+ -- approximate fanin check
+ if (#conMatrixL)[1] % thickness == 0 then
+ -- do a proper fanin check and set revTable
+ local fanin = (#conMatrixL)[1] / thickness
+ local revTable = torch.Tensor(thickness, fanin, 2)
+ for ii=1,thickness do
+ local tempf = fanin
+ for jj=1,(#conMatrixL)[1] do
+ if conMatrixL[jj][2] == ii then
+ if tempf <= 0 then break end
+ revTable[ii][tempf][1] = conMatrixL[jj][1]
+ revTable[ii][tempf][2] = jj
+ tempf = tempf - 1
+ end
+ end
+ if tempf ~= 0 then
+ fanin = -1
+ break
+ end
+ end
+ if fanin ~= -1 then
+ return revTable
+ end
+ end
+ return {}
+end
+
function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH)
parent.__init(self)
@@ -65,9 +96,9 @@ function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH)
self.dW = dW
self.dH = dH
self.connTable = conMatrix
+ self.connTableRev = constructTableRev(conMatrix)
self.nInputPlane = self.connTable:select(2,1):max()
self.nOutputPlane = self.connTable:select(2,2):max()
-
self.weight = torch.Tensor(self.connTable:size(1), kH, kW)
self.bias = torch.Tensor(self.nOutputPlane)
self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW)