-rw-r--r--   Module.lua                  | 89
-rw-r--r--   SpatialConvolutionMap.lua   | 33
-rw-r--r--   dok/index.dok               |  3
-rw-r--r--   test/test.lua               | 99
4 files changed, 165 insertions, 59 deletions
@@ -139,70 +139,47 @@ function Module:getParameters()
    -- get parameters
    local parameters,gradParameters = self:parameters()
 
+   local function storageInSet(set, storage) --this is waste of time (need correct hash)
+      for key, val in pairs(set) do
+         if key == storage then
+            return val
+         end
+      end
+   end
+
    -- this function flattens arbitrary lists of parameters,
    -- even complex shared ones
    local function flatten(parameters)
-      -- already flat ?
-      local flat = true
-      for k = 2,#parameters do
-         if parameters[k]:storage() ~= parameters[k-1]:storage() then
-            flat = false
-            break
+      local storages = {}
+      local nParameters = 0
+      for k = 1,#parameters do
+         if not storageInSet(storages, parameters[k]:storage()) then
+            storages[parameters[k]:storage()] = nParameters
+            nParameters = nParameters + parameters[k]:storage():size()
          end
       end
-      if flat then
-         local nParameters = 0
-         for k,param in ipairs(parameters) do
-            nParameters = nParameters + param:nElement()
-         end
-         local flatParameters = parameters[1].new(parameters[1]:storage())
-         if nParameters ~= flatParameters:nElement() then
-            error('flattenParameters(): weird parameters')
-         end
-         return flatParameters
+
+      local flatParameters = torch.Tensor(nParameters):fill(1)
+      local flatStorage = flatParameters:storage()
+
+      for k = 1,#parameters do
+         local storageOffset = storageInSet(storages, parameters[k]:storage())
+         parameters[k]:set(flatStorage,
+                           storageOffset + parameters[k]:storageOffset(),
+                           parameters[k]:size(),
+                           parameters[k]:stride())
+         parameters[k]:zero()
       end
-      -- compute offsets of each parameter
-      local offsets = {}
-      local sizes = {}
-      local strides = {}
-      local elements = {}
-      local storageOffsets = {}
-      local params = {}
-      local nParameters = 0
-      for k,param in ipairs(parameters) do
-         table.insert(offsets, nParameters+1)
-         table.insert(sizes, param:size())
-         table.insert(strides, param:stride())
-         table.insert(elements, param:nElement())
-         table.insert(storageOffsets, param:storageOffset())
-         local isView = false
-         for i = 1,k-1 do
-            if param:storage() == parameters[i]:storage() then
-               offsets[k] = offsets[i]
-               if storageOffsets[k] ~= storageOffsets[i] or elements[k] ~= elements[i] then
-                  error('flattenParameters(): cannot flatten shared weights with different structures')
-               end
-               isView = true
-               break
-            end
-         end
-         if not isView then
-            nParameters = nParameters + param:nElement()
-         end
+      if (flatParameters:sum() ~= 0) then
+         print("<getParameters()> WARNING: found "
+               .. flatParameters:sum() .. " holes in the parameters vector (i.e. "
+               .. flatParameters:sum() .. " storage elements that are unused, this "
+               .. "might be an issue for your optimization procedure)")
       end
-      -- create flat vector
-      local flatParameters = parameters[1].new(nParameters)
-      local storage = flatParameters:storage()
-      -- reallocate all parameters in flat vector
-      for i = 1,#parameters do
-         local data = parameters[i]:clone()
-         parameters[i]:set(storage, offsets[i], elements[i]):resize(sizes[i],strides[i]):copy(data)
-         data = nil
-         collectgarbage()
+
+      for k, v in pairs(storages) do
+         flatParameters[{{v+1,v+k:size()}}]:copy(torch.Tensor():set(k))
       end
-      -- cleanup
-      collectgarbage()
-      -- return flat param
       return flatParameters
    end
diff --git a/SpatialConvolutionMap.lua b/SpatialConvolutionMap.lua
index 4b525ba..11718fd 100644
--- a/SpatialConvolutionMap.lua
+++ b/SpatialConvolutionMap.lua
@@ -54,6 +54,37 @@ function nn.tables.random(nin, nout, nto)
    return tbl
 end
 
+function constructTableRev(conMatrix)
+   local conMatrixL = conMatrix:type('torch.LongTensor')
+   -- Construct reverse lookup connection table
+   local thickness = conMatrixL:select(2,2):max()
+   -- approximate fanin check
+   if (#conMatrixL)[1] % thickness == 0 then
+      -- do a proper fanin check and set revTable
+      local fanin = (#conMatrixL)[1] / thickness
+      local revTable = torch.Tensor(thickness, fanin, 2)
+      for ii=1,thickness do
+         local tempf = fanin
+         for jj=1,(#conMatrixL)[1] do
+            if conMatrixL[jj][2] == ii then
+               if tempf <= 0 then break end
+               revTable[ii][tempf][1] = conMatrixL[jj][1]
+               revTable[ii][tempf][2] = jj
+               tempf = tempf - 1
+            end
+         end
+         if tempf ~= 0 then
+            fanin = -1
+            break
+         end
+      end
+      if fanin ~= -1 then
+         return revTable
+      end
+   end
+   return {}
+end
+
 function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH)
    parent.__init(self)
 
@@ -65,9 +96,9 @@ function SpatialConvolutionMap:__init(conMatrix, kW, kH, dW, dH)
    self.dW = dW
    self.dH = dH
    self.connTable = conMatrix
+   self.connTableRev = constructTableRev(conMatrix)
    self.nInputPlane = self.connTable:select(2,1):max()
    self.nOutputPlane = self.connTable:select(2,2):max()
-
    self.weight = torch.Tensor(self.connTable:size(1), kH, kW)
    self.bias = torch.Tensor(self.nOutputPlane)
    self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW)
diff --git a/dok/index.dok b/dok/index.dok
index a687db3..8e27fa4 100644
--- a/dok/index.dok
+++ b/dok/index.dok
@@ -268,6 +268,9 @@ situations.
 Keep in mind that, this function uses a simple trick to achieve its goal and
 it might not be valid for a custom module.
 
+Also note that compared to accGradParameters(), the gradients are not retained
+for future use.
+
 <file lua>
 function Module:accUpdateGradParameters(input, gradOutput, lr)
    local gradWeight = self.gradWeight
diff --git a/test/test.lua b/test/test.lua
index 4d4383d..ff8b0d9 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -1078,10 +1078,105 @@ function nntest.VolumetricConvolution()
    mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ')
 end
 
+function nntest.Module_getParameters_1()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   local p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'getParameters(): weights wrong')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'getParameters(): bias wrong')
+end
+
+function nntest.Module_getParameters_2()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   local p = n:getParameters()
+
+   n:add( nn.Linear(10,10) )
+   p = n:getParameters()
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when appending new module')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when appending new module')
+end
+
+function nntest.Module_getParameters_3()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   n:add( n.modules[1]:clone() )
+   local p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning')
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when using cloning')
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[1].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[1].bias):norm(), 0, 'error when using cloning')
+
+   n:reset()
+
+   mytester:assertgt((p[{ {111,210} }] - n.modules[1].weight):norm(), 0, 'error when using cloning')
+   mytester:assertgt((p[{ {211,220} }] - n.modules[1].bias):norm(), 0, 'error when using cloning')
+end
+
+function nntest.Module_getParameters_4()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   n:add( n.modules[1]:clone() )
+   local p = n:getParameters()
+
+   n:add(nn.Linear(10,10))
+   p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning')
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[2].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[2].bias):norm(), 0, 'error when using cloning')
+
+   mytester:asserteq((p[{ {221,320} }] - n.modules[3].weight):norm(), 0, 'error when using cloning')
+   mytester:asserteq((p[{ {321,330} }] - n.modules[3].bias):norm(), 0, 'error when using cloning')
+end
+
+function nntest.Module_getParameters_5()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   n:add( n.modules[1]:clone('weight','bias') )
+   local p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning+sharing')
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing')
+
+   n:reset()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing')
+end
+
+function nntest.Module_getParameters_6()
+   local n = nn.Sequential()
+   n:add( nn.Linear(10,10) )
+   n:add( n.modules[1]:clone('weight','bias') )
+   local p = n:getParameters()
+
+   n:add(nn.Linear(10,10))
+   p = n:getParameters()
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[1].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[1].bias):norm(), 0, 'error when using cloning+sharing')
+
+   mytester:asserteq((p[{ {1,100} }] - n.modules[2].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {101,110} }] - n.modules[2].bias):norm(), 0, 'error when using cloning+sharing')
+
+   mytester:asserteq((p[{ {111,210} }] - n.modules[3].weight):norm(), 0, 'error when using cloning+sharing')
+   mytester:asserteq((p[{ {211,220} }] - n.modules[3].bias):norm(), 0, 'error when using cloning+sharing')
+end
+
 mytester:add(nntest)
 
---mytester:add(test_SpatialConvolution)
---mytester:add(test_AbsCriterion)
 
 if not nn then
    require 'nn'
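
Taken together, the new flatten() and the Module_getParameters_* tests above mean that after getParameters() every parameter tensor, including shared ones, becomes a view into one contiguous vector. A minimal usage sketch of that behaviour, not taken from the patch and with illustrative layer sizes, mirroring the Module_getParameters_5 test:

    require 'nn'

    -- Two Linear(10,10) layers; the second shares weight and bias storage with the first.
    local n = nn.Sequential()
    n:add( nn.Linear(10,10) )
    n:add( n.modules[1]:clone('weight','bias') )

    -- Flatten: p has 110 elements (100 weights + 10 biases, counted once because they are shared).
    local p = n:getParameters()
    print(p:nElement())                                           -- 110

    -- The modules are now views into p's storage, so writing through the flat
    -- vector (as an optimizer would) is visible in both layers.
    p:fill(0.5)
    print(n.modules[1].weight[1][1], n.modules[2].weight[1][1])   -- 0.5  0.5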