Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 65
1 file changed, 45 insertions, 20 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 78db6af..d728b67 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -48,7 +48,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
          print("in order to waste less memory on indexes.")
       end
    end
-   
+
    self.nChildNode = nChildNode
    self.nParentNode = nParentNode
    self.minNodeId = minNodeId
@@ -56,7 +56,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
    self.maxParentId = maxParentId
    self.maxChildId = maxChildId
    self.maxFamily = maxFamily
-   
+
    -- initialize weights and biases
    self.weight = torch.Tensor(self.nChildNode, self.inputSize)
    self.bias = torch.Tensor(self.nChildNode)
@@ -64,12 +64,12 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
       self.gradWeight = torch.Tensor(self.nChildNode, self.inputSize)
       self.gradBias = torch.Tensor(self.nChildNode)
    end
-   
+
    -- contains all childIds
    self.childIds = torch.IntTensor(self.nChildNode)
    -- contains all parentIds
    self.parentIds = torch.IntTensor(parentIds)
-   
+
    -- index of children by parentId
    self.parentChildren = torch.IntTensor(self.maxParentId, 2):fill(-1)
    local start = 1
@@ -81,7 +81,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
       self.childIds:narrow(1, start, nChildren):copy(children)
       start = start + nChildren
    end
-   
+
    -- index of parent by childId
    self.childParent = torch.IntTensor(self.maxChildId, 2):fill(-1)
    for parentIdx=1,self.parentIds:size(1) do
@@ -97,20 +97,20 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
          child[2] = childIdx
       end
    end
-   
-   -- used to allocate buffers 
+
+   -- used to allocate buffers
    -- max nChildren in family path
    local maxFamilyPath = -999999999
    -- max number of parents
    local maxDept = -999999999
    local treeSizes = {[rootId] = self.parentChildren[rootId][2]}
    local pathSizes = {[rootId] = 1}
-   local function getSize(nodeId) 
+   local function getSize(nodeId)
       local treeSize, pathSize = treeSizes[nodeId], pathSizes[nodeId]
       if not treeSize then
          local parentId = self.childParent[nodeId][1]
         local nChildren = self.parentChildren[nodeId][2]
-         treeSize, pathSize = getSize(parentId) 
+         treeSize, pathSize = getSize(parentId)
         treeSize = treeSize + nChildren
         pathSize = pathSize + 1
         treeSizes[nodeId] = treeSize
@@ -126,21 +126,21 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, ver
    end
    self.maxFamilyPath = maxFamilyPath
    self.maxDept = maxDept
-   
+
    -- stores the parentIds of nodes that have been accGradParameters
    self.updates = {}
-   
+
    -- used internally to store intermediate outputs or gradOutputs
    self._nodeBuffer = torch.Tensor()
    self._multiBuffer = torch.Tensor()
-   
+
    self.batchSize = 0
-   
+
    self._gradInput = torch.Tensor()
    self._gradTarget = torch.IntTensor() -- dummy
    self.gradInput = {self._gradInput, self._gradTarget}
    self.static = (static == nil) and true or static
-   
+
    self:reset()
 end
 
@@ -162,7 +162,7 @@ function SoftMaxTree:updateOutput(inputTable)
       self._multiBuffer:resize(input:size(1)*self.maxFamilyPath)
       self.batchSize = input:size(1)
       -- so that it works within nn.ConcatTable :
-      self._gradTarget:resizeAs(target):zero() 
+      self._gradTarget:resizeAs(target):zero()
       if self._nodeUpdateHost then
          self._nodeUpdateHost:resize(input:size(1),self.maxDept)
          self._nodeUpdateCuda:resize(input:size(1),self.maxDept)
@@ -281,7 +281,7 @@ function SoftMaxTree:type(type, typecache)
    if type == torch.type(self.weight) then
       return self
    end
-   
+
    local hierarchy = self.hierarchy
    self.hierarchy = nil
    self._nodeUpdateHost = nil
@@ -301,16 +301,16 @@ function SoftMaxTree:type(type, typecache)
    local parentIds = self.parentIds
    self.parentIds = nil
    self._gradOutput = nil
-   
+
    parent.type(self, type, typecache)
-   
+
    self.hierarchy = hierarchy
    self.parentChildren = parentChildren
    self.childParent = childParent
    self._gradTarget = _gradTarget
    self.childIds = childIds
    self.parentIds = parentIds
-   
+
    if (type == 'torch.CudaTensor') then
       -- cunnx needs this for filling self.updates
       self._nodeUpdateHost = torch.IntTensor()
@@ -327,7 +327,7 @@ function SoftMaxTree:type(type, typecache)
       self.childParent = self.childParent:type('torch.IntTensor')
       self._gradTarget = self._gradTarget:type('torch.IntTensor')
    end
-   self.gradInput = {self._gradInput, self._gradTarget} 
+   self.gradInput = {self._gradInput, self._gradTarget}
    self.batchSize = 0 --so that buffers are resized
    return self
 end
@@ -343,5 +343,30 @@ function SoftMaxTree:maxNorm(maxNorm)
    end
 end
 
+function SoftMaxTree:momentumGradParameters()
+   -- get dense view of momGradParams
+   local _ = require 'moses'
+   if not self.momGradParams or _.isEmpty(self.momGradParams) then
+      assert(not self.accUpdate, "cannot use momentum with accUpdate")
+      self.momGradParams = {self.gradWeight:clone():zero(), self.gradBias:clone():zero()}
+   end
+   local momGradParams = self.momGradParams
+   if self.static and not _.isEmpty(self.updates) then
+      local momGradWeight = momGradParams[1]
+      local momGradBias = momGradParams[2]
+      momGradParams = {}
+      -- only return the parameters affected by the forward/backward
+      for parentId, scale in pairs(self.updates) do
+         local node = self.parentChildren:select(1, parentId)
+         local parentIdx = node[1]
+         local nChildren = node[2]
+         momGradParams[parentId] = momGradWeight:narrow(1, parentIdx, nChildren)
+         local biasId = parentId+self.maxParentId
+         momGradParams[biasId] = momGradBias:narrow(1, parentIdx, nChildren)
+      end
+   end
+   return momGradParams
+end
+
 -- we do not need to accumulate parameters when sharing
 SoftMaxTree.sharedAccUpdateGradParameters = SoftMaxTree.accUpdateGradParameters
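Most of this diff strips trailing whitespace; the substantive change is the new SoftMaxTree:momentumGradParameters() hook at the bottom. It lets momentum-based optimizers keep velocity buffers aligned with the module's sparse parameter views: in static mode, weight chunks are keyed by parentId and bias chunks by parentId + self.maxParentId, mirroring the keying of SoftMaxTree:parameters(), so gradient and momentum tables line up entry for entry. The consumer below is a minimal sketch, not part of this commit; the toy hierarchy, batch data, and plain momentum update (v = mom*v + g; p = p - lr*v) are illustrative assumptions.

   require 'nnx'

   -- hypothetical tree: root 1 with leaf children 2, 3, 4
   local smt = nn.SoftMaxTree(10, {[1] = torch.IntTensor{2, 3, 4}}, 1)

   -- a forward/backward pass records the touched parentIds in smt.updates
   local input, target = torch.randn(5, 10), torch.IntTensor{2, 3, 2, 4, 2}
   smt:forward{input, target}
   smt:backward({input, target}, torch.randn(5))

   -- momentum step over only the parameter chunks touched above
   local lr, mom = 0.1, 0.9
   local params, gradParams = smt:parameters()
   local momGradParams = smt:momentumGradParameters()
   for k, gradParam in pairs(gradParams) do
      local v = momGradParams[k]  -- narrow() view into the dense buffer
      v:mul(mom):add(gradParam)   -- v = mom*v + g
      params[k]:add(-lr, v)       -- p = p - lr*v
   end

Since each returned chunk is a narrow() into the dense momGradParams tensors, the in-place update inside the loop writes through to the persistent buffers, so velocities carry over between sparse updates.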