From ff7a90e5d2717ec9f543121f95973ab550a37189 Mon Sep 17 00:00:00 2001
From: Nicholas Leonard
Date: Thu, 16 Feb 2017 11:43:53 -0500
Subject: SoftMaxForest + unit test fixes

---
 SoftMaxForest.lua | 55 ++++++++++++++++++++++++++++++++++++++++++++++
 SoftMaxTree.lua   | 65 ++++++++++++++++++++++++++++++++++++++-----------------
 init.lua          | 11 +++++-----
 test/test-all.lua | 62 ++++++++++++++++++++++++++--------------------------
 4 files changed, 137 insertions(+), 56 deletions(-)
 create mode 100644 SoftMaxForest.lua

diff --git a/SoftMaxForest.lua b/SoftMaxForest.lua
new file mode 100644
index 0000000..0fee1fd
--- /dev/null
+++ b/SoftMaxForest.lua
@@ -0,0 +1,55 @@
+local SoftMaxForest, parent = torch.class("nn.SoftMaxForest", "nn.Container")
+
+function SoftMaxForest:__init(inputSize, trees, rootIds, gaterSize, gaterAct, accUpdate)
+   local gaterAct = gaterAct or nn.Tanh()
+   local gaterSize = gaterSize or {}
+
+   -- experts
+   self.experts = nn.ConcatTable()
+   self.smts = {}
+   for i,tree in ipairs(trees) do
+      local smt = nn.SoftMaxTree(inputSize, tree, rootIds[i], accUpdate)
+      table.insert(self.smts, smt)
+      self.experts:add(smt)
+   end
+
+   -- gater
+   self.gater = nn.Sequential()
+   self.gater:add(nn.SelectTable(1)) -- ignore targets
+   for i,hiddenSize in ipairs(gaterSize) do
+      self.gater:add(nn.Linear(inputSize, hiddenSize))
+      self.gater:add(gaterAct:clone())
+      inputSize = hiddenSize
+   end
+   self.gater:add(nn.Linear(inputSize, self.experts:size()))
+   self.gater:add(nn.SoftMax())
+
+   -- mixture
+   self.trunk = nn.ConcatTable()
+   self.trunk:add(self.gater)
+   self.trunk:add(self.experts)
+   self.mixture = nn.MixtureTable()
+   self.module = nn.Sequential()
+   self.module:add(self.trunk)
+   self.module:add(self.mixture)
+   parent.__init(self)
+   self.modules[1] = self.module
+end
+
+function SoftMaxForest:updateOutput(input)
+   self.output = self.module:updateOutput(input)
+   return self.output
+end
+
+function SoftMaxForest:updateGradInput(input, gradOutput)
+   self.gradInput = self.module:updateGradInput(input, gradOutput)
+   return self.gradInput
+end
+
+function SoftMaxForest:accGradParameters(input, gradOutput, scale)
+   self.module:accGradParameters(input, gradOutput, scale)
+end
+
+function SoftMaxForest:accUpdateGradParameters(input, gradOutput, lr)
+   self.module:accUpdateGradParameters(input, gradOutput, lr)
+end
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 78db6af..d728b67 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -48,7 +48,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
          print("in order to waste less memory on indexes.")
       end
    end
-   
+
    self.nChildNode = nChildNode
    self.nParentNode = nParentNode
    self.minNodeId = minNodeId
@@ -56,7 +56,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
    self.maxParentId = maxParentId
    self.maxChildId = maxChildId
    self.maxFamily = maxFamily
-   
+
    -- initialize weights and biases
    self.weight = torch.Tensor(self.nChildNode, self.inputSize)
    self.bias = torch.Tensor(self.nChildNode)
@@ -64,12 +64,12 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
       self.gradWeight = torch.Tensor(self.nChildNode, self.inputSize)
       self.gradBias = torch.Tensor(self.nChildNode)
    end
-   
+
    -- contains all childIds
    self.childIds = torch.IntTensor(self.nChildNode)
    -- contains all parentIds
    self.parentIds = torch.IntTensor(parentIds)
-   
+
    -- index of children by parentId
    self.parentChildren = torch.IntTensor(self.maxParentId, 2):fill(-1)
    local start = 1
@@ -81,7 +81,7 @@ function
SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver self.childIds:narrow(1, start, nChildren):copy(children) start = start + nChildren end - + -- index of parent by childId self.childParent = torch.IntTensor(self.maxChildId, 2):fill(-1) for parentIdx=1,self.parentIds:size(1) do @@ -97,20 +97,20 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver child[2] = childIdx end end - - -- used to allocate buffers + + -- used to allocate buffers -- max nChildren in family path local maxFamilyPath = -999999999 -- max number of parents local maxDept = -999999999 local treeSizes = {[rootId] = self.parentChildren[rootId][2]} local pathSizes = {[rootId] = 1} - local function getSize(nodeId) + local function getSize(nodeId) local treeSize, pathSize = treeSizes[nodeId], pathSizes[nodeId] if not treeSize then local parentId = self.childParent[nodeId][1] local nChildren = self.parentChildren[nodeId][2] - treeSize, pathSize = getSize(parentId) + treeSize, pathSize = getSize(parentId) treeSize = treeSize + nChildren pathSize = pathSize + 1 treeSizes[nodeId] = treeSize @@ -126,21 +126,21 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver end self.maxFamilyPath = maxFamilyPath self.maxDept = maxDept - + -- stores the parentIds of nodes that have been accGradParameters self.updates = {} - + -- used internally to store intermediate outputs or gradOutputs self._nodeBuffer = torch.Tensor() self._multiBuffer = torch.Tensor() - + self.batchSize = 0 - + self._gradInput = torch.Tensor() self._gradTarget = torch.IntTensor() -- dummy self.gradInput = {self._gradInput, self._gradTarget} self.static = (static == nil) and true or static - + self:reset() end @@ -162,7 +162,7 @@ function SoftMaxTree:updateOutput(inputTable) self._multiBuffer:resize(input:size(1)*self.maxFamilyPath) self.batchSize = input:size(1) -- so that it works within nn.ConcatTable : - self._gradTarget:resizeAs(target):zero() + self._gradTarget:resizeAs(target):zero() if self._nodeUpdateHost then self._nodeUpdateHost:resize(input:size(1),self.maxDept) self._nodeUpdateCuda:resize(input:size(1),self.maxDept) @@ -281,7 +281,7 @@ function SoftMaxTree:type(type, typecache) if type == torch.type(self.weight) then return self end - + local hierarchy = self.hierarchy self.hierarchy = nil self._nodeUpdateHost = nil @@ -301,16 +301,16 @@ function SoftMaxTree:type(type, typecache) local parentIds = self.parentIds self.parentIds = nil self._gradOutput = nil - + parent.type(self, type, typecache) - + self.hierarchy = hierarchy self.parentChildren = parentChildren self.childParent = childParent self._gradTarget = _gradTarget self.childIds = childIds self.parentIds = parentIds - + if (type == 'torch.CudaTensor') then -- cunnx needs this for filling self.updates self._nodeUpdateHost = torch.IntTensor() @@ -327,7 +327,7 @@ function SoftMaxTree:type(type, typecache) self.childParent = self.childParent:type('torch.IntTensor') self._gradTarget = self._gradTarget:type('torch.IntTensor') end - self.gradInput = {self._gradInput, self._gradTarget} + self.gradInput = {self._gradInput, self._gradTarget} self.batchSize = 0 --so that buffers are resized return self end @@ -343,5 +343,30 @@ function SoftMaxTree:maxNorm(maxNorm) end end +function SoftMaxTree:momentumGradParameters() + -- get dense view of momGradParams + local _ = require 'moses' + if not self.momGradParams or _.isEmpty(self.momGradParams) then + assert(not self.accUpdate, "cannot use momentum with accUpdate") + self.momGradParams = 
{self.gradWeight:clone():zero(), self.gradBias:clone():zero()} + end + local momGradParams = self.momGradParams + if self.static and not _.isEmpty(self.updates) then + local momGradWeight = momGradParams[1] + local momGradBias = momGradParams[2] + momGradParams = {} + -- only return the parameters affected by the forward/backward + for parentId, scale in pairs(self.updates) do + local node = self.parentChildren:select(1, parentId) + local parentIdx = node[1] + local nChildren = node[2] + momGradParams[parentId] = momGradWeight:narrow(1, parentIdx, nChildren) + local biasId = parentId+self.maxParentId + momGradParams[biasId] = momGradBias:narrow(1, parentIdx, nChildren) + end + end + return momGradParams +end + -- we do not need to accumulate parameters when sharing SoftMaxTree.sharedAccUpdateGradParameters = SoftMaxTree.accUpdateGradParameters diff --git a/init.lua b/init.lua index 4abe66a..f40bbe9 100644 --- a/init.lua +++ b/init.lua @@ -1,9 +1,9 @@ ---------------------------------------------------------------------- -- --- Copyright (c) 2011 Clement Farabet, Marco Scoffier, +-- Copyright (c) 2011 Clement Farabet, Marco Scoffier, -- Koray Kavukcuoglu, Benoit Corda -- --- +-- -- Permission is hereby granted, free of charge, to any person obtaining -- a copy of this software and associated documentation files (the -- "Software"), to deal in the Software without restriction, including @@ -11,10 +11,10 @@ -- distribute, sublicense, and/or sell copies of the Software, and to -- permit persons to whom the Software is furnished to do so, subject to -- the following conditions: --- +-- -- The above copyright notice and this permission notice shall be -- included in all copies or substantial portions of the Software. --- +-- -- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND @@ -22,7 +22,7 @@ -- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
--- +-- ---------------------------------------------------------------------- require 'torch' @@ -69,6 +69,7 @@ require('nnx.FunctionWrapper') require('nnx.SaturatedLU') require('nnx.Minus') require('nnx.SoftMaxTree') +require('nnx.SoftMaxForest') require('nnx.MultiSoftMax') require('nnx.Balance') require('nnx.PushTable') diff --git a/test/test-all.lua b/test/test-all.lua index 80ed910..b284cd4 100644 --- a/test/test-all.lua +++ b/test/test-all.lua @@ -3,7 +3,7 @@ local nnxtest = {} local precision = 1e-5 local mytester --- you can easily test specific units like this: +-- you can easily test specific units like this: -- th -lnnx -e "nnx.test{'MultiSoftMax'}" -- th -lnnx -e "nnx.test{'SoftMaxTree', 'Balance'}" @@ -16,7 +16,7 @@ function nnxtest.SpatialPadding() local pad_t = math.random(0,8) local pad_b = math.random(0,8) local val = torch.randn(1):squeeze() - local module = nn.SpatialPadding(pad_l, pad_r, pad_t, pad_b, val) + local module = nn.SpatialPadding(pad_l, pad_r, pad_t, pad_b, nil, nil, val) local input = torch.rand(fanin,sizey,sizex) local err = nn.Jacobian.testJacobian(module, input) @@ -82,10 +82,10 @@ end local function template_SpatialReSamplingEx(up, mode) for iTest = 1,3 do - local nDims = math.random(2,6) + local nDims = math.random(2,3) local dims = torch.LongStorage(nDims) for i = 1,nDims do - dims[i] = math.random(5,20/nDims) + dims[i] = math.random(5,torch.round(20/nDims)) end local xratio, yratio if up then @@ -102,10 +102,10 @@ local function template_SpatialReSamplingEx(up, mode) local module = nn.SpatialReSamplingEx({owidth=owidth_, oheight=oheight_, xDim=xdim, yDim = ydim, mode=mode}) local input = torch.rand(dims) - + local err = nn.Jacobian.testJacobian(module, input) mytester:assertlt(err, precision, 'error on state ') - + local ferr, berr = nn.Jacobian.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err ') @@ -137,8 +137,8 @@ end function nnxtest.SpatialDownSampling() local fanin = math.random(1,4) - local sizex = math.random(11,4) - local sizey = math.random(11,4) + local sizex = math.random(4,11) + local sizey = math.random(4,11) local mx = math.random(2,6) local my = math.random(2,6) local module = nn.SpatialDownSampling(mx,my) @@ -172,21 +172,21 @@ function nnxtest.SpatialReSampling_1() local batchSize = math.random(4,8) local input2 = torch.rand(batchSize,fanin,sizey,sizex) input2[2]:copy(input) - + local output = module:forward(input):clone() local output2 = module:forward(input2) mytester:assertTensorEq(output, output2[2], 0.00001, 'SpatialResampling batch forward err') - + local gradInput = module:backward(input, output):clone() local gradInput2 = module:backward(input2, output2) mytester:assertTensorEq(gradInput, gradInput2[2], 0.00001, 'SpatialResampling batch backward err') - + -- test rwidth/rheight local input = torch.randn(3,8,10) local module = nn.SpatialReSampling{rwidth=0.5,rheight=0.5} local output = module:forward(input) mytester:assertTableEq(output:size():totable(), {3, 4, 5}, 0.00000001, 'SpatialResampling batch rwidth/rheight err') - + local input = torch.randn(2,3,8,10) local module = nn.SpatialReSampling{rwidth=0.5,rheight=0.5} local output = module:forward(input) @@ -408,7 +408,7 @@ local function template_SpatialMatching(channels, iwidth, iheight, maxw, maxh, f local input = torch.rand(2, channels, iheight, iwidth) local err = nn.Jacobian.testJacobian(module, input) mytester:assertlt(err, precision, 'error on state ') - + local ferr, berr = nn.Jacobian.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ') mytester:asserteq(berr, 0, torch.typename(module) .. 
' - i/o backward err ') @@ -426,7 +426,7 @@ function nnxtest.SoftMaxTree() local grad = torch.randn(5) local root_id = 29 local hierarchy={ - [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5}, + [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5}, [2]=torch.IntTensor{6,7,8}, [3]=torch.IntTensor{9,10,11}, [4]=torch.IntTensor{12,13,14}, [5]=torch.IntTensor{15,16,17}, [6]=torch.IntTensor{18,19,20}, [7]=torch.IntTensor{21,22,23}, @@ -439,7 +439,7 @@ function nnxtest.SoftMaxTree() local indices = {3,3,4} local parentIds = {29,2,8} local linears = {} - + for i,parentId in ipairs(parentIds) do local s = nn.Sequential() local linear = nn.Linear(100,hierarchy[parentId]:size(1)) @@ -512,7 +512,7 @@ end function nnxtest.TreeNLLCriterion() local input = torch.randn(5,10) local target = torch.ones(5) --all targets are 1 - local c = nn.TreeNLLCriterion() + local c = nn.TreeNLLCriterion() -- the targets are actually ignored (SoftMaxTree uses them before TreeNLLCriterion) local err = c:forward(input, target) gradInput = c:backward(input, target) @@ -577,10 +577,10 @@ local function blur(mean, stdv, size) end function nnxtest.Balance() - local inputSize = 7 + local inputSize = 7 local batchSize = 3 local nBatch = 1 - + local input = torch.randn(batchSize, inputSize):mul(0.1):float() for i=1,batchSize do input[i]:add(blur(3, 1, inputSize):float()) @@ -591,36 +591,36 @@ function nnxtest.Balance() local gradOutput = torch.randn(batchSize, inputSize):float() local bl = nn.Balance(nBatch) bl:float() - + local output = bl:forward(input) local p_y = output:sum(1):div(output:sum()) mytester:assert(p_y:std() < 0.02) mytester:assert(math.abs(p_y:sum() - 1) < 0.000001) - + local gradInput = bl:backward(input, gradOutput) end function nnxtest.MultiSoftMax() - local inputSize = 7 + local inputSize = 7 local nSoftmax = 5 local batchSize = 3 - + local input = torch.randn(batchSize, nSoftmax, inputSize) local gradOutput = torch.randn(batchSize, nSoftmax, inputSize) local msm = nn.MultiSoftMax() - + local output = msm:forward(input) local gradInput = msm:backward(input, gradOutput) mytester:assert(output:isSameSizeAs(input)) mytester:assert(gradOutput:isSameSizeAs(gradInput)) - + local sm = nn.SoftMax() local input2 = input:view(batchSize*nSoftmax, inputSize) local output2 = sm:forward(input2) local gradInput2 = sm:backward(input2, gradOutput:view(batchSize*nSoftmax, inputSize)) - - mytester:assertTensorEq(output, output2, 0.000001) - mytester:assertTensorEq(gradInput, gradInput2, 0.000001) + + mytester:assertTensorEq(output:view(-1), output2:view(-1), 0.000001) + mytester:assertTensorEq(gradInput:view(-1), gradInput2:view(-1), 0.000001) end function nnxtest.PushPullTable() @@ -630,14 +630,14 @@ function nnxtest.PushPullTable() local gradOutput = torch.randn(5) local root_id = 29 local hierarchy={ - [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5}, + [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5}, [2]=torch.IntTensor{6,7,8}, [3]=torch.IntTensor{9,10,11}, [4]=torch.IntTensor{12,13,14}, [5]=torch.IntTensor{15,16,17}, [6]=torch.IntTensor{18,19,20}, [7]=torch.IntTensor{21,22,23}, [8]=torch.IntTensor{24,25,26,27,28} } local smt = nn.SoftMaxTree(100, hierarchy, root_id) - -- create a network where inputs are fed through softmaxtree + -- create a network where inputs are fed through softmaxtree -- and targets are teleported (pushed then pulled) to softmaxtree local mlp = nn.Sequential() local linear = nn.Linear(50,100) @@ -663,7 +663,7 @@ function nnxtest.PushPullTable() 
    mytester:assertTensorEq(output, output2, 0.00001, "push/pull forward error")
    mytester:assertTensorEq(gradInput[1], gradInput[1], 0.00001, "push/pull backward error")
    mytester:assertTensorEq(gradInput[2], gradInput[2], 0.00001, "push/pull backward error")
-   
+
    -- test multi-pull case
    local mlp = nn.Sequential()
    local push = nn.PushTable(2)
@@ -680,7 +680,7 @@ function nnxtest.PushPullTable()
    mytester:assertTensorEq(output[4], inputTable[2], 0.00001, "push/pull multi-forward error")
    local gradOutput = {inputTable[2]:clone(), inputTable[1]:clone(), inputTable[2]:clone(), inputTable[2]:clone()}
    local gradInput = mlp:backward(inputTable, gradOutput)
-   local gradInput2 = inputTable[2]:clone():mul(3) 
+   local gradInput2 = inputTable[2]:clone():mul(3)
    mytester:assertTensorEq(gradInput[1], gradInput[1], 0.00001, "push/pull multi-backward error")
    mytester:assertTensorEq(gradInput[2], gradInput[2], 0.00001, "push/pull multi-backward error")
 end