diff options
author | Soumith Chintala <soumith@gmail.com> | 2012-08-24 03:03:40 +0400 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2012-08-24 03:03:40 +0400 |
commit | 3c36e6d5cc010a9e5d41aec92b8620a6dd3073b3 (patch) | |
tree | d06102a402465be4df452bcffaa3a89382788e2f | |
parent | 84c3611ffbd7f60b9c10b2bd8e3dc951f1944992 (diff) |
added cuda convolutionMap
-rw-r--r-- | SpatialConvolutionMap.lua | 33 |
1 files changed, 32 insertions, 1 deletions
diff --git a/SpatialConvolutionMap.lua b/SpatialConvolutionMap.lua index 4b525ba..11718fd 100644 --- a/SpatialConvolutionMap.lua +++ b/SpatialConvolutionMap.lua @@ -54,6 +54,37 @@ function nn.tables.random(nin, nout, nto) return tbl end +function constructTableRev(conMatrix) + local conMatrixL = conMatrix:type('torch.LongTensor') + -- Construct reverse lookup connection table + local thickness = conMatrixL:select(2,2):max() + -- approximate fanin check + if (#conMatrixL)[1] % thickness == 0 then + -- do a proper fanin check and set revTable + local fanin = (#conMatrixL)[1] / thickness + local revTable = torch.Tensor(thickness, fanin, 2) + for ii=1,thickness do + local tempf = fanin + for jj=1,(#conMatrixL)[1] do + if conMatrixL[jj][2] == ii then + if tempf <= 0 then break end + revTable[ii][tempf][1] = conMatrixL[jj][1] + revTable[ii][tempf][2] = jj + tempf = tempf - 1 + end + end + if tempf ~= 0 then + fanin = -1 + break + end + end + if fanin ~= -1 then + return revTable + end + end + return {} +end + function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH) parent.__init(self) @@ -65,9 +96,9 @@ function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH) self.dW = dW self.dH = dH self.connTable = conMatrix + self.connTableRev = constructTableRev(conMatrix) self.nInputPlane = self.connTable:select(2,1):max() self.nOutputPlane = self.connTable:select(2,2):max() - self.weight = torch.Tensor(self.connTable:size(1), kH, kW) self.bias = torch.Tensor(self.nOutputPlane) self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW) |