
github.com/clementfarabet/lua---nnx.git
author    nicholas-leonard <nick@nikopia.org>  2014-06-11 23:05:28 +0400
committer nicholas-leonard <nick@nikopia.org>  2014-06-11 23:05:28 +0400
commit    e70745e5b3cd4f2eb3fd15011aff2c7d891820a2 (patch)
tree      ed6d37a54aa1353f0361085fcd4b13f089fe9ec6 /SoftMaxTree.lua
parent    f9ecf8df49f1276a9e58ff99ff639a0d32d80512 (diff)
added maxNorm argument to SoftMaxTree
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--  SoftMaxTree.lua | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
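
The diff below adds a maxNorm argument to the SoftMaxTree constructor, inserted before verbose, and stores it for use as a max-norm constraint during parameter updates. As a rough usage sketch of the new signature (the hierarchy table format, a parent id mapped to a torch.IntTensor of child ids, and the concrete ids here are assumptions taken from the module's documentation, not from this diff):

   require 'nnx'

   -- hypothetical hierarchy: root node -1 with two internal nodes, each with two leaves
   local hierarchy = {
      [-1] = torch.IntTensor{1, 2},
      [1]  = torch.IntTensor{3, 4},
      [2]  = torch.IntTensor{5, 6}
   }

   -- new signature: inputSize, hierarchy, rootId, maxNorm, verbose
   local smt = nn.SoftMaxTree(100, hierarchy, -1, 2)  -- cap each weight row's L2 norm at 2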
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 9f3ba1e..07c07ef 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -6,9 +6,10 @@ local SoftMaxTree, parent = torch.class('nn.SoftMaxTree', 'nn.Module')
 -- Only works with a tree (one parent per child)
 ------------------------------------------------------------------------
-function SoftMaxTree:__init(inputSize, hierarchy, rootId, verbose)
+function SoftMaxTree:__init(inputSize, hierarchy, rootId, maxNorm, verbose)
    parent.__init(self)
    self.rootId = rootId or 1
+   self.maxNorm = maxNorm or 1
    self.inputSize = inputSize
    assert(type(hierarchy) == 'table', "Expecting table at arg 2")
    -- get the total amount of children (non-root nodes)
@@ -204,10 +205,18 @@ function SoftMaxTree:parameters(static)
 end
 
 function SoftMaxTree:updateParameters(learningRate, partial)
+   local maxNorm = self.maxNorm
+   if partial and nn.SoftMaxTree_updateParameters then
+      print"here"
+      os.exit()
+   end
    local params, gradParams = self:parameters(partial)
    if params then
       for k,param in pairs(params) do
          param:add(-learningRate, gradParams[k])
+         if maxNorm then
+            param:renorm(2,1,maxNorm)
+         end
       end
    end
 end
@@ -245,6 +254,8 @@ function SoftMaxTree:type(type)
       -- cunnx needs this for filling self.updates
      self._nodeUpdateHost = torch.IntTensor()
      self._nodeUpdateCuda = torch.CudaTensor()
+      self._paramUpdateHost = torch.IntTensor()
+      self._paramUpdateCuda = torch.CudaTensor()
      self.parentChildrenCuda = self.parentChildren:type(type)
      self.childParentCuda = self.childParent:type(type)
    elseif self.nodeUpdateHost then
@@ -283,6 +294,7 @@ function SoftMaxTree:sharedClone()
    smt.maxDept = self.maxDept
    smt.gradWeight = self.gradWeight:clone()
    smt.gradBias = self.gradBias:clone()
+   smt.maxNorm = self.maxNorm
    return smt:share(self, 'weight', 'bias')
 end
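
The core of the new behavior is param:renorm(2, 1, maxNorm) in updateParameters: after the gradient step, any row (dimension 1) of a parameter tensor whose L2 norm (p = 2) exceeds maxNorm is rescaled down to that norm, while rows under the limit are left untouched. A minimal sketch of the effect, independent of SoftMaxTree and assuming (as the committed code does) that the method form of renorm operates in place:

   require 'torch'

   local maxNorm = 1
   local w = torch.Tensor{{3, 4},      -- row norm 5: rescaled down to norm 1
                          {0.3, 0.4}}  -- row norm 0.5: left as-is

   w:renorm(2, 1, maxNorm)  -- p=2 norm, over rows (dim 1), capped at maxNorm

   print(w[1]:norm())  -- ~1.0
   print(w[2]:norm())  -- 0.5

Capping row norms this way constrains each node's weight vector after the update without changing the direction of the gradient step itself.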