Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-07-16 07:48:38 +0400
committernicholas-leonard <nick@nikopia.org>2014-07-16 07:48:38 +0400
commitc5b89c15c968ea830ab9908744d37bd54b4cee41 (patch)
treeafc06c875e1d5fe5165a53089e90be59185c0420 /SoftMaxTree.lua
parent4437925e8d3369ebbb8ee3dc3dcaf21aa29bdef2 (diff)
SoftMaxTree.static
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--SoftMaxTree.lua18
1 files changed, 10 insertions, 8 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index f37396e..25eb4e5 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -6,7 +6,7 @@ local SoftMaxTree, parent = torch.class('nn.SoftMaxTree', 'nn.Module')
-- Only works with a tree (one parent per child)
------------------------------------------------------------------------
-function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, verbose)
+function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, verbose)
parent.__init(self)
self.rootId = rootId or 1
self.inputSize = inputSize
@@ -139,6 +139,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, verbose)
self._gradInput = torch.Tensor()
self._gradTarget = torch.IntTensor() -- dummy
self.gradInput = {self._gradInput, self._gradTarget}
+ self.static = (static == nil) and true or static
self:reset()
end
@@ -186,7 +187,8 @@ end
-- when static is true, return parameters with static keys
-- i.e. keys that don't change from batch to batch
-function SoftMaxTree:parameters(static)
+function SoftMaxTree:parameters()
+ static = self.static
local params, grads = {}, {}
local updated = false
for parentId, scale in pairs(self.updates) do
@@ -233,9 +235,9 @@ function SoftMaxTree:parameters(static)
return params, grads
end
-function SoftMaxTree:updateParameters(learningRate, partial)
+function SoftMaxTree:updateParameters(learningRate)
assert(not self.accUpdate)
- local params, gradParams = self:parameters(partial)
+ local params, gradParams = self:parameters()
if params then
for k,param in pairs(params) do
param:add(-learningRate, gradParams[k])
@@ -257,8 +259,8 @@ function SoftMaxTree:getNodeParameters(parentId)
return {weight, bias}
end
-function SoftMaxTree:zeroGradParameters(partial)
- local _,gradParams = self:parameters(partial)
+function SoftMaxTree:zeroGradParameters()
+ local _,gradParams = self:parameters()
for k,gradParam in pairs(gradParams) do
gradParam:zero()
end
@@ -332,8 +334,8 @@ function SoftMaxTree:sharedClone()
return smt:share(self, 'weight', 'bias')
end
-function SoftMaxTree:maxNorm(maxNorm, partial)
- local params = self:parameters(partial)
+function SoftMaxTree:maxNorm(maxNorm)
+ local params = self:parameters()
if params then
for k,param in pairs(params) do
if param:dim() == 2 and maxNorm then