
github.com/clementfarabet/lua---nnx.git
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- SoftMaxTree.lua | 65
1 file changed, 45 insertions(+), 20 deletions(-)
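
For orientation, a minimal construction sketch (hypothetical toy hierarchy; the nn.SoftMaxTree constructor signature is the one visible in the hunk headers below):

require 'nnx'
-- toy hierarchy: parentId -> IntTensor of childIds; -1 is the root node
local hierarchy = {[-1] = torch.IntTensor{1, 2}, [2] = torch.IntTensor{3, 4, 5}}
local smt = nn.SoftMaxTree(100, hierarchy, -1)
-- forward takes {input, target}; it should return one log-likelihood per sample
local output = smt:forward{torch.randn(4, 100), torch.IntTensor{3, 4, 5, 1}}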
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 78db6af..d728b67 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -48,7 +48,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
print("in order to waste less memory on indexes.")
end
end
-
+
self.nChildNode = nChildNode
self.nParentNode = nParentNode
self.minNodeId = minNodeId
@@ -56,7 +56,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
self.maxParentId = maxParentId
self.maxChildId = maxChildId
self.maxFamily = maxFamily
-
+
-- initialize weights and biases
self.weight = torch.Tensor(self.nChildNode, self.inputSize)
self.bias = torch.Tensor(self.nChildNode)
@@ -64,12 +64,12 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
self.gradWeight = torch.Tensor(self.nChildNode, self.inputSize)
self.gradBias = torch.Tensor(self.nChildNode)
end
-
+
-- contains all childIds
self.childIds = torch.IntTensor(self.nChildNode)
-- contains all parentIds
self.parentIds = torch.IntTensor(parentIds)
-
+
-- index of children by parentId
self.parentChildren = torch.IntTensor(self.maxParentId, 2):fill(-1)
local start = 1
@@ -81,7 +81,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
self.childIds:narrow(1, start, nChildren):copy(children)
start = start + nChildren
end
-
+
-- index of parent by childId
self.childParent = torch.IntTensor(self.maxChildId, 2):fill(-1)
for parentIdx=1,self.parentIds:size(1) do
@@ -97,20 +97,20 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
child[2] = childIdx
end
end
-
- -- used to allocate buffers
+
+ -- used to allocate buffers
-- max nChildren in family path
local maxFamilyPath = -999999999
-- max number of parents
local maxDept = -999999999
local treeSizes = {[rootId] = self.parentChildren[rootId][2]}
local pathSizes = {[rootId] = 1}
- local function getSize(nodeId)
+ local function getSize(nodeId)
local treeSize, pathSize = treeSizes[nodeId], pathSizes[nodeId]
if not treeSize then
local parentId = self.childParent[nodeId][1]
local nChildren = self.parentChildren[nodeId][2]
- treeSize, pathSize = getSize(parentId)
+ treeSize, pathSize = getSize(parentId)
treeSize = treeSize + nChildren
pathSize = pathSize + 1
treeSizes[nodeId] = treeSize
@@ -126,21 +126,21 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
end
self.maxFamilyPath = maxFamilyPath
self.maxDept = maxDept
-
+
-- stores the parentIds of nodes that have been accGradParameters
self.updates = {}
-
+
-- used internally to store intermediate outputs or gradOutputs
self._nodeBuffer = torch.Tensor()
self._multiBuffer = torch.Tensor()
-
+
self.batchSize = 0
-
+
self._gradInput = torch.Tensor()
self._gradTarget = torch.IntTensor() -- dummy
self.gradInput = {self._gradInput, self._gradTarget}
self.static = (static == nil) and true or static
-
+
self:reset()
end
@@ -162,7 +162,7 @@ function SoftMaxTree:updateOutput(inputTable)
self._multiBuffer:resize(input:size(1)*self.maxFamilyPath)
self.batchSize = input:size(1)
-- so that it works within nn.ConcatTable :
- self._gradTarget:resizeAs(target):zero()
+ self._gradTarget:resizeAs(target):zero()
if self._nodeUpdateHost then
self._nodeUpdateHost:resize(input:size(1),self.maxDept)
self._nodeUpdateCuda:resize(input:size(1),self.maxDept)
@@ -281,7 +281,7 @@ function SoftMaxTree:type(type, typecache)
if type == torch.type(self.weight) then
return self
end
-
+
local hierarchy = self.hierarchy
self.hierarchy = nil
self._nodeUpdateHost = nil
@@ -301,16 +301,16 @@ function SoftMaxTree:type(type, typecache)
local parentIds = self.parentIds
self.parentIds = nil
self._gradOutput = nil
-
+
parent.type(self, type, typecache)
-
+
self.hierarchy = hierarchy
self.parentChildren = parentChildren
self.childParent = childParent
self._gradTarget = _gradTarget
self.childIds = childIds
self.parentIds = parentIds
-
+
if (type == 'torch.CudaTensor') then
-- cunnx needs this for filling self.updates
self._nodeUpdateHost = torch.IntTensor()
@@ -327,7 +327,7 @@ function SoftMaxTree:type(type, typecache)
self.childParent = self.childParent:type('torch.IntTensor')
self._gradTarget = self._gradTarget:type('torch.IntTensor')
end
- self.gradInput = {self._gradInput, self._gradTarget}
+ self.gradInput = {self._gradInput, self._gradTarget}
self.batchSize = 0 --so that buffers are resized
return self
end
@@ -343,5 +343,30 @@ function SoftMaxTree:maxNorm(maxNorm)
end
end

+function SoftMaxTree:momentumGradParameters()
+ -- get dense view of momGradParams
+ local _ = require 'moses'
+ if not self.momGradParams or _.isEmpty(self.momGradParams) then
+ assert(not self.accUpdate, "cannot use momentum with accUpdate")
+ self.momGradParams = {self.gradWeight:clone():zero(), self.gradBias:clone():zero()}
+ end
+ local momGradParams = self.momGradParams
+ if self.static and not _.isEmpty(self.updates) then
+ local momGradWeight = momGradParams[1]
+ local momGradBias = momGradParams[2]
+ momGradParams = {}
+ -- only return the parameters affected by the forward/backward
+ for parentId, scale in pairs(self.updates) do
+ local node = self.parentChildren:select(1, parentId)
+ local parentIdx = node[1]
+ local nChildren = node[2]
+ momGradParams[parentId] = momGradWeight:narrow(1, parentIdx, nChildren)
+ local biasId = parentId+self.maxParentId
+ momGradParams[biasId] = momGradBias:narrow(1, parentIdx, nChildren)
+ end
+ end
+ return momGradParams
+end
+
-- we do not need to accumulate parameters when sharing
SoftMaxTree.sharedAccUpdateGradParameters = SoftMaxTree.accUpdateGradParameters
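
A minimal usage sketch of the new momentumGradParameters method (hypothetical toy setup, reusing the constructor sketched above; requires the moses package that the patch itself requires):

require 'nnx'
local hierarchy = {[-1] = torch.IntTensor{1, 2}, [2] = torch.IntTensor{3, 4, 5}}
local smt = nn.SoftMaxTree(10, hierarchy, -1) -- static defaults to true (see __init)
local input, target = torch.randn(2, 10), torch.IntTensor{3, 4}
smt:forward{input, target}
smt:backward({input, target}, torch.ones(2)) -- accGradParameters fills smt.updates
-- for a static module, only the momentum buffers of parent nodes touched by the
-- forward/backward are returned, keyed by parentId (bias buffers are keyed by
-- parentId + maxParentId, as in the patch):
for id, grad in pairs(smt:momentumGradParameters()) do
   print(id, grad:size())
end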