Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2012-08-24 03:03:40 +0400
committerSoumith Chintala <soumith@gmail.com>2012-08-24 03:03:40 +0400
commit3c36e6d5cc010a9e5d41aec92b8620a6dd3073b3 (patch)
treed06102a402465be4df452bcffaa3a89382788e2f
parent84c3611ffbd7f60b9c10b2bd8e3dc951f1944992 (diff)
added cuda convolutionMap
-rw-r--r--SpatialConvolutionMap.lua33
1 files changed, 32 insertions, 1 deletions
diff --git a/SpatialConvolutionMap.lua b/SpatialConvolutionMap.lua
index 4b525ba..11718fd 100644
--- a/SpatialConvolutionMap.lua
+++ b/SpatialConvolutionMap.lua
@@ -54,6 +54,37 @@ function nn.tables.random(nin, nout, nto)
return tbl
end
+function constructTableRev(conMatrix)
+ local conMatrixL = conMatrix:type('torch.LongTensor')
+ -- Construct reverse lookup connection table
+ local thickness = conMatrixL:select(2,2):max()
+ -- approximate fanin check
+ if (#conMatrixL)[1] % thickness == 0 then
+ -- do a proper fanin check and set revTable
+ local fanin = (#conMatrixL)[1] / thickness
+ local revTable = torch.Tensor(thickness, fanin, 2)
+ for ii=1,thickness do
+ local tempf = fanin
+ for jj=1,(#conMatrixL)[1] do
+ if conMatrixL[jj][2] == ii then
+ if tempf <= 0 then break end
+ revTable[ii][tempf][1] = conMatrixL[jj][1]
+ revTable[ii][tempf][2] = jj
+ tempf = tempf - 1
+ end
+ end
+ if tempf ~= 0 then
+ fanin = -1
+ break
+ end
+ end
+ if fanin ~= -1 then
+ return revTable
+ end
+ end
+ return {}
+end
+
function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH)
parent.__init(self)
@@ -65,9 +96,9 @@ function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH)
self.dW = dW
self.dH = dH
self.connTable = conMatrix
+ self.connTableRev = constructTableRev(conMatrix)
self.nInputPlane = self.connTable:select(2,1):max()
self.nOutputPlane = self.connTable:select(2,2):max()
-
self.weight = torch.Tensor(self.connTable:size(1), kH, kW)
self.bias = torch.Tensor(self.nOutputPlane)
self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW)