github.com/clementfarabet/lua---nnx.git
author    Nicholas Leonard <nleonard@twitter.com>  2017-02-16 19:43:53 +0300
committer Nicholas Leonard <nleonard@twitter.com>  2017-02-16 19:43:53 +0300
commit    ff7a90e5d2717ec9f543121f95973ab550a37189 (patch)
tree      4b755f33ccc56b52c4e849d6c8d6336560bea155
parent    c3df4fbe7b7e18c41f120d461e08be8373ba8376 (diff)

SoftMaxForest + unit test fixes

-rw-r--r--  SoftMaxForest.lua   55
-rw-r--r--  SoftMaxTree.lua     65
-rw-r--r--  init.lua            11
-rw-r--r--  test/test-all.lua   62
4 files changed, 137 insertions(+), 56 deletions(-)
diff --git a/SoftMaxForest.lua b/SoftMaxForest.lua
new file mode 100644
index 0000000..0fee1fd
--- /dev/null
+++ b/SoftMaxForest.lua
@@ -0,0 +1,55 @@
+local SoftMaxForest, parent = torch.class("nn.SoftMaxForest", "nn.Container")
+
+function SoftMaxForest:__init(inputSize, trees, rootIds, gaterSize, gaterAct, accUpdate)
+   local gaterAct = gaterAct or nn.Tanh()
+   local gaterSize = gaterSize or {}
+
+   -- experts
+   self.experts = nn.ConcatTable()
+   self.smts = {}
+   for i,tree in ipairs(trees) do
+      local smt = nn.SoftMaxTree(inputSize, tree, rootIds[i], accUpdate)
+      table.insert(self.smts, smt)
+      self.experts:add(smt)
+   end
+
+   -- gater
+   self.gater = nn.Sequential()
+   self.gater:add(nn.SelectTable(1)) -- ignore targets
+   for i,hiddenSize in ipairs(gaterSize) do
+      self.gater:add(nn.Linear(inputSize, hiddenSize))
+      self.gater:add(gaterAct:clone())
+      inputSize = hiddenSize
+   end
+   self.gater:add(nn.Linear(inputSize, self.experts:size()))
+   self.gater:add(nn.SoftMax())
+
+   -- mixture
+   self.trunk = nn.ConcatTable()
+   self.trunk:add(self.gater)
+   self.trunk:add(self.experts)
+   self.mixture = nn.MixtureTable()
+   self.module = nn.Sequential()
+   self.module:add(self.trunk)
+   self.module:add(self.mixture)
+   parent.__init(self)
+   self.modules[1] = self.module
+end
+
+function SoftMaxForest:updateOutput(input)
+   self.output = self.module:updateOutput(input)
+   return self.output
+end
+
+function SoftMaxForest:updateGradInput(input, gradOutput)
+   self.gradInput = self.module:updateGradInput(input, gradOutput)
+   return self.gradInput
+end
+
+function SoftMaxForest:accGradParameters(input, gradOutput, scale)
+   self.module:accGradParameters(input, gradOutput, scale)
+end
+
+function SoftMaxForest:accUpdateGradParameters(input, gradOutput, lr)
+   self.module:accUpdateGradParameters(input, gradOutput, lr)
+end
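
(For context, not part of this commit: a minimal usage sketch of the new nn.SoftMaxForest. The hierarchy, sizes, and targets below are hypothetical and mirror the SoftMaxTree unit test further down in test/test-all.lua.)

-- build a forest of two SoftMaxTrees over the same toy hierarchy,
-- gated by a one-hidden-layer MLP (50 Tanh units)
require 'nnx'

local inputSize = 100
local rootId = 29
local hierarchy = {
   [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
   [2]=torch.IntTensor{6,7,8}
}
local forest = nn.SoftMaxForest(inputSize, {hierarchy, hierarchy}, {rootId, rootId}, {50}, nn.Tanh())

-- like nn.SoftMaxTree, the forest takes an {activation, target} table;
-- the gater's nn.SelectTable(1) drops the targets before gating
local input = torch.randn(5, inputSize)
local target = torch.IntTensor{3,4,6,7,30}
local output = forest:forward{input, target}

Each expert scores its target under its own tree; nn.MixtureTable blends the expert outputs with the gater weights.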
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 78db6af..d728b67 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -48,7 +48,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
print("in order to waste less memory on indexes.")
end
end
-
+
self.nChildNode = nChildNode
self.nParentNode = nParentNode
self.minNodeId = minNodeId
@@ -56,7 +56,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
self.maxParentId = maxParentId
self.maxChildId = maxChildId
self.maxFamily = maxFamily
-
+
-- initialize weights and biases
self.weight = torch.Tensor(self.nChildNode, self.inputSize)
self.bias = torch.Tensor(self.nChildNode)
@@ -64,12 +64,12 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
self.gradWeight = torch.Tensor(self.nChildNode, self.inputSize)
self.gradBias = torch.Tensor(self.nChildNode)
end
-
+
-- contains all childIds
self.childIds = torch.IntTensor(self.nChildNode)
-- contains all parentIds
self.parentIds = torch.IntTensor(parentIds)
-
+
-- index of children by parentId
self.parentChildren = torch.IntTensor(self.maxParentId, 2):fill(-1)
local start = 1
@@ -81,7 +81,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
self.childIds:narrow(1, start, nChildren):copy(children)
start = start + nChildren
end
-
+
-- index of parent by childId
self.childParent = torch.IntTensor(self.maxChildId, 2):fill(-1)
for parentIdx=1,self.parentIds:size(1) do
@@ -97,20 +97,20 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
child[2] = childIdx
end
end
-
- -- used to allocate buffers
+
+ -- used to allocate buffers
-- max nChildren in family path
local maxFamilyPath = -999999999
-- max number of parents
local maxDept = -999999999
local treeSizes = {[rootId] = self.parentChildren[rootId][2]}
local pathSizes = {[rootId] = 1}
- local function getSize(nodeId)
+ local function getSize(nodeId)
local treeSize, pathSize = treeSizes[nodeId], pathSizes[nodeId]
if not treeSize then
local parentId = self.childParent[nodeId][1]
local nChildren = self.parentChildren[nodeId][2]
- treeSize, pathSize = getSize(parentId)
+ treeSize, pathSize = getSize(parentId)
treeSize = treeSize + nChildren
pathSize = pathSize + 1
treeSizes[nodeId] = treeSize
@@ -126,21 +126,21 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
end
self.maxFamilyPath = maxFamilyPath
self.maxDept = maxDept
-
+
-- stores the parentIds of nodes that have been accGradParameters
self.updates = {}
-
+
-- used internally to store intermediate outputs or gradOutputs
self._nodeBuffer = torch.Tensor()
self._multiBuffer = torch.Tensor()
-
+
self.batchSize = 0
-
+
self._gradInput = torch.Tensor()
self._gradTarget = torch.IntTensor() -- dummy
self.gradInput = {self._gradInput, self._gradTarget}
self.static = (static == nil) and true or static
-
+
self:reset()
end
@@ -162,7 +162,7 @@ function SoftMaxTree:updateOutput(inputTable)
self._multiBuffer:resize(input:size(1)*self.maxFamilyPath)
self.batchSize = input:size(1)
-- so that it works within nn.ConcatTable :
- self._gradTarget:resizeAs(target):zero()
+ self._gradTarget:resizeAs(target):zero()
if self._nodeUpdateHost then
self._nodeUpdateHost:resize(input:size(1),self.maxDept)
self._nodeUpdateCuda:resize(input:size(1),self.maxDept)
@@ -281,7 +281,7 @@ function SoftMaxTree:type(type, typecache)
if type == torch.type(self.weight) then
return self
end
-
+
local hierarchy = self.hierarchy
self.hierarchy = nil
self._nodeUpdateHost = nil
@@ -301,16 +301,16 @@ function SoftMaxTree:type(type, typecache)
local parentIds = self.parentIds
self.parentIds = nil
self._gradOutput = nil
-
+
parent.type(self, type, typecache)
-
+
self.hierarchy = hierarchy
self.parentChildren = parentChildren
self.childParent = childParent
self._gradTarget = _gradTarget
self.childIds = childIds
self.parentIds = parentIds
-
+
if (type == 'torch.CudaTensor') then
-- cunnx needs this for filling self.updates
self._nodeUpdateHost = torch.IntTensor()
@@ -327,7 +327,7 @@ function SoftMaxTree:type(type, typecache)
self.childParent = self.childParent:type('torch.IntTensor')
self._gradTarget = self._gradTarget:type('torch.IntTensor')
end
- self.gradInput = {self._gradInput, self._gradTarget}
+ self.gradInput = {self._gradInput, self._gradTarget}
self.batchSize = 0 --so that buffers are resized
return self
end
@@ -343,5 +343,30 @@ function SoftMaxTree:maxNorm(maxNorm)
end
end
+function SoftMaxTree:momentumGradParameters()
+ -- get dense view of momGradParams
+ local _ = require 'moses'
+ if not self.momGradParams or _.isEmpty(self.momGradParams) then
+ assert(not self.accUpdate, "cannot use momentum with accUpdate")
+ self.momGradParams = {self.gradWeight:clone():zero(), self.gradBias:clone():zero()}
+ end
+ local momGradParams = self.momGradParams
+ if self.static and not _.isEmpty(self.updates) then
+ local momGradWeight = momGradParams[1]
+ local momGradBias = momGradParams[2]
+ momGradParams = {}
+ -- only return the parameters affected by the forward/backward
+ for parentId, scale in pairs(self.updates) do
+ local node = self.parentChildren:select(1, parentId)
+ local parentIdx = node[1]
+ local nChildren = node[2]
+ momGradParams[parentId] = momGradWeight:narrow(1, parentIdx, nChildren)
+ local biasId = parentId+self.maxParentId
+ momGradParams[biasId] = momGradBias:narrow(1, parentIdx, nChildren)
+ end
+ end
+ return momGradParams
+end
+
-- we do not need to accumulate parameters when sharing
SoftMaxTree.sharedAccUpdateGradParameters = SoftMaxTree.accUpdateGradParameters
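
(A hedged consumer sketch for the new momentumGradParameters accessor above, not part of this commit: for a static tree whose updates table is non-empty after a forward/backward pass, the returned table is sparse, keyed by parentId for weight slices and by parentId + maxParentId for bias slices. smt is assumed to be an nn.SoftMaxTree that was just forwarded and backwarded.)

local momGrads = smt:momentumGradParameters()
for id, slice in pairs(momGrads) do
   if id <= smt.maxParentId then
      -- nChildren x inputSize view into the weight momentum
      print(string.format("parent %d: weight slice %dx%d", id, slice:size(1), slice:size(2)))
   else
      -- nChildren view into the bias momentum
      print(string.format("parent %d: bias slice of %d", id - smt.maxParentId, slice:size(1)))
   end
end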
diff --git a/init.lua b/init.lua
index 4abe66a..f40bbe9 100644
--- a/init.lua
+++ b/init.lua
@@ -1,9 +1,9 @@
----------------------------------------------------------------------
--
--- Copyright (c) 2011 Clement Farabet, Marco Scoffier,
+-- Copyright (c) 2011 Clement Farabet, Marco Scoffier,
-- Koray Kavukcuoglu, Benoit Corda
--
---
+--
-- Permission is hereby granted, free of charge, to any person obtaining
-- a copy of this software and associated documentation files (the
-- "Software"), to deal in the Software without restriction, including
@@ -11,10 +11,10 @@
-- distribute, sublicense, and/or sell copies of the Software, and to
-- permit persons to whom the Software is furnished to do so, subject to
-- the following conditions:
---
+--
-- The above copyright notice and this permission notice shall be
-- included in all copies or substantial portions of the Software.
---
+--
-- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
@@ -22,7 +22,7 @@
-- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
-- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
---
+--
----------------------------------------------------------------------
require 'torch'
@@ -69,6 +69,7 @@ require('nnx.FunctionWrapper')
require('nnx.SaturatedLU')
require('nnx.Minus')
require('nnx.SoftMaxTree')
+require('nnx.SoftMaxForest')
require('nnx.MultiSoftMax')
require('nnx.Balance')
require('nnx.PushTable')
diff --git a/test/test-all.lua b/test/test-all.lua
index 80ed910..b284cd4 100644
--- a/test/test-all.lua
+++ b/test/test-all.lua
@@ -3,7 +3,7 @@ local nnxtest = {}
local precision = 1e-5
local mytester
--- you can easily test specific units like this:
+-- you can easily test specific units like this:
-- th -lnnx -e "nnx.test{'MultiSoftMax'}"
-- th -lnnx -e "nnx.test{'SoftMaxTree', 'Balance'}"
@@ -16,7 +16,7 @@ function nnxtest.SpatialPadding()
local pad_t = math.random(0,8)
local pad_b = math.random(0,8)
local val = torch.randn(1):squeeze()
- local module = nn.SpatialPadding(pad_l, pad_r, pad_t, pad_b, val)
+ local module = nn.SpatialPadding(pad_l, pad_r, pad_t, pad_b, nil, nil, val)
local input = torch.rand(fanin,sizey,sizex)
local err = nn.Jacobian.testJacobian(module, input)
@@ -82,10 +82,10 @@ end
local function template_SpatialReSamplingEx(up, mode)
for iTest = 1,3 do
- local nDims = math.random(2,6)
+ local nDims = math.random(2,3)
local dims = torch.LongStorage(nDims)
for i = 1,nDims do
- dims[i] = math.random(5,20/nDims)
+ dims[i] = math.random(5,torch.round(20/nDims))
end
local xratio, yratio
if up then
@@ -102,10 +102,10 @@ local function template_SpatialReSamplingEx(up, mode)
local module = nn.SpatialReSamplingEx({owidth=owidth_, oheight=oheight_,
xDim=xdim, yDim = ydim, mode=mode})
local input = torch.rand(dims)
-
+
local err = nn.Jacobian.testJacobian(module, input)
mytester:assertlt(err, precision, 'error on state ')
-
+
local ferr, berr = nn.Jacobian.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
@@ -137,8 +137,8 @@ end
function nnxtest.SpatialDownSampling()
local fanin = math.random(1,4)
- local sizex = math.random(11,4)
- local sizey = math.random(11,4)
+ local sizex = math.random(4,11)
+ local sizey = math.random(4,11)
local mx = math.random(2,6)
local my = math.random(2,6)
local module = nn.SpatialDownSampling(mx,my)
@@ -172,21 +172,21 @@ function nnxtest.SpatialReSampling_1()
local batchSize = math.random(4,8)
local input2 = torch.rand(batchSize,fanin,sizey,sizex)
input2[2]:copy(input)
-
+
local output = module:forward(input):clone()
local output2 = module:forward(input2)
mytester:assertTensorEq(output, output2[2], 0.00001, 'SpatialResampling batch forward err')
-
+
local gradInput = module:backward(input, output):clone()
local gradInput2 = module:backward(input2, output2)
mytester:assertTensorEq(gradInput, gradInput2[2], 0.00001, 'SpatialResampling batch backward err')
-
+
-- test rwidth/rheight
local input = torch.randn(3,8,10)
local module = nn.SpatialReSampling{rwidth=0.5,rheight=0.5}
local output = module:forward(input)
mytester:assertTableEq(output:size():totable(), {3, 4, 5}, 0.00000001, 'SpatialResampling batch rwidth/rheight err')
-
+
local input = torch.randn(2,3,8,10)
local module = nn.SpatialReSampling{rwidth=0.5,rheight=0.5}
local output = module:forward(input)
@@ -408,7 +408,7 @@ local function template_SpatialMatching(channels, iwidth, iheight, maxw, maxh, f
local input = torch.rand(2, channels, iheight, iwidth)
local err = nn.Jacobian.testJacobian(module, input)
mytester:assertlt(err, precision, 'error on state ')
-
+
local ferr, berr = nn.Jacobian.testIO(module, input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
@@ -426,7 +426,7 @@ function nnxtest.SoftMaxTree()
local grad = torch.randn(5)
local root_id = 29
local hierarchy={
- [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
+ [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
[2]=torch.IntTensor{6,7,8}, [3]=torch.IntTensor{9,10,11},
[4]=torch.IntTensor{12,13,14}, [5]=torch.IntTensor{15,16,17},
[6]=torch.IntTensor{18,19,20}, [7]=torch.IntTensor{21,22,23},
@@ -439,7 +439,7 @@ function nnxtest.SoftMaxTree()
local indices = {3,3,4}
local parentIds = {29,2,8}
local linears = {}
-
+
for i,parentId in ipairs(parentIds) do
local s = nn.Sequential()
local linear = nn.Linear(100,hierarchy[parentId]:size(1))
@@ -512,7 +512,7 @@ end
function nnxtest.TreeNLLCriterion()
local input = torch.randn(5,10)
local target = torch.ones(5) --all targets are 1
- local c = nn.TreeNLLCriterion()
+ local c = nn.TreeNLLCriterion()
-- the targets are actually ignored (SoftMaxTree uses them before TreeNLLCriterion)
local err = c:forward(input, target)
gradInput = c:backward(input, target)
@@ -577,10 +577,10 @@ local function blur(mean, stdv, size)
end
function nnxtest.Balance()
- local inputSize = 7
+ local inputSize = 7
local batchSize = 3
local nBatch = 1
-
+
local input = torch.randn(batchSize, inputSize):mul(0.1):float()
for i=1,batchSize do
input[i]:add(blur(3, 1, inputSize):float())
@@ -591,36 +591,36 @@ function nnxtest.Balance()
local gradOutput = torch.randn(batchSize, inputSize):float()
local bl = nn.Balance(nBatch)
bl:float()
-
+
local output = bl:forward(input)
local p_y = output:sum(1):div(output:sum())
mytester:assert(p_y:std() < 0.02)
mytester:assert(math.abs(p_y:sum() - 1) < 0.000001)
-
+
local gradInput = bl:backward(input, gradOutput)
end
function nnxtest.MultiSoftMax()
- local inputSize = 7
+ local inputSize = 7
local nSoftmax = 5
local batchSize = 3
-
+
local input = torch.randn(batchSize, nSoftmax, inputSize)
local gradOutput = torch.randn(batchSize, nSoftmax, inputSize)
local msm = nn.MultiSoftMax()
-
+
local output = msm:forward(input)
local gradInput = msm:backward(input, gradOutput)
mytester:assert(output:isSameSizeAs(input))
mytester:assert(gradOutput:isSameSizeAs(gradInput))
-
+
local sm = nn.SoftMax()
local input2 = input:view(batchSize*nSoftmax, inputSize)
local output2 = sm:forward(input2)
local gradInput2 = sm:backward(input2, gradOutput:view(batchSize*nSoftmax, inputSize))
-
- mytester:assertTensorEq(output, output2, 0.000001)
- mytester:assertTensorEq(gradInput, gradInput2, 0.000001)
+
+ mytester:assertTensorEq(output:view(-1), output2:view(-1), 0.000001)
+ mytester:assertTensorEq(gradInput:view(-1), gradInput2:view(-1), 0.000001)
end
function nnxtest.PushPullTable()
@@ -630,14 +630,14 @@ function nnxtest.PushPullTable()
local gradOutput = torch.randn(5)
local root_id = 29
local hierarchy={
- [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
+ [29]=torch.IntTensor{30,1,2}, [1]=torch.IntTensor{3,4,5},
[2]=torch.IntTensor{6,7,8}, [3]=torch.IntTensor{9,10,11},
[4]=torch.IntTensor{12,13,14}, [5]=torch.IntTensor{15,16,17},
[6]=torch.IntTensor{18,19,20}, [7]=torch.IntTensor{21,22,23},
[8]=torch.IntTensor{24,25,26,27,28}
}
local smt = nn.SoftMaxTree(100, hierarchy, root_id)
- -- create a network where inputs are fed through softmaxtree
+ -- create a network where inputs are fed through softmaxtree
-- and targets are teleported (pushed then pulled) to softmaxtree
local mlp = nn.Sequential()
local linear = nn.Linear(50,100)
@@ -663,7 +663,7 @@ function nnxtest.PushPullTable()
mytester:assertTensorEq(output, output2, 0.00001, "push/pull forward error")
mytester:assertTensorEq(gradInput[1], gradInput[1], 0.00001, "push/pull backward error")
mytester:assertTensorEq(gradInput[2], gradInput[2], 0.00001, "push/pull backward error")
-
+
-- test multi-pull case
local mlp = nn.Sequential()
local push = nn.PushTable(2)
@@ -680,7 +680,7 @@ function nnxtest.PushPullTable()
mytester:assertTensorEq(output[4], inputTable[2], 0.00001, "push/pull multi-forward error")
local gradOutput = {inputTable[2]:clone(), inputTable[1]:clone(), inputTable[2]:clone(), inputTable[2]:clone()}
local gradInput = mlp:backward(inputTable, gradOutput)
- local gradInput2 = inputTable[2]:clone():mul(3)
+ local gradInput2 = inputTable[2]:clone():mul(3)
mytester:assertTensorEq(gradInput[1], gradInput[1], 0.00001, "push/pull multi-backward error")
mytester:assertTensorEq(gradInput[2], gradInput[2], 0.00001, "push/pull multi-backward error")
end