diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-06-11 23:05:28 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-06-11 23:05:28 +0400 |
commit | e70745e5b3cd4f2eb3fd15011aff2c7d891820a2 (patch) | |
tree | ed6d37a54aa1353f0361085fcd4b13f089fe9ec6 /SoftMaxTree.lua | |
parent | f9ecf8df49f1276a9e58ff99ff639a0d32d80512 (diff) |
added maxNorm argument to SoftMaxTree
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index 9f3ba1e..07c07ef 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -6,9 +6,10 @@ local SoftMaxTree, parent = torch.class('nn.SoftMaxTree', 'nn.Module') -- Only works with a tree (one parent per child) ------------------------------------------------------------------------ -function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose) +function SoftMaxTree:__init(inputSize, hierarchy, rootId, maxNorm, verbose) parent.__init(self) self.rootId = rootId or 1 + self.maxNorm = maxNorm or 1 self.inputSize = inputSize assert(type(hierarchy) == 'table', "Expecting table at arg 2") -- get the total amount of children (non-root nodes) @@ -204,10 +205,18 @@ function SoftMaxTree:parameters(static) end function SoftMaxTree:updateParameters(learningRate, partial) + local maxNorm = self.maxNorm + if partial and nn.SoftMaxTree_updateParameters then + print"here" + os.exit() + end local params, gradParams = self:parameters(partial) if params then for k,param in pairs(params) do param:add(-learningRate, gradParams[k]) + if maxNorm then + param:renorm(2,1,maxNorm) + end end end end @@ -245,6 +254,8 @@ function SoftMaxTree:type(type) -- cunnx needs this for filling self.updates self._nodeUpdateHost = torch.IntTensor() self._nodeUpdateCuda = torch.CudaTensor() + self._paramUpdateHost = torch.IntTensor() + self._paramUpdateCuda = torch.CudaTensor() self.parentChildrenCuda = self.parentChildren:type(type) self.childParentCuda = self.childParent:type(type) elseif self.nodeUpdateHost then @@ -283,6 +294,7 @@ function SoftMaxTree:sharedClone() smt.maxDept = self.maxDept smt.gradWeight = self.gradWeight:clone() smt.gradBias = self.gradBias:clone() + smt.maxNorm = self.maxNorm return smt:share(self, 'weight', 'bias') end |