author | nicholas-leonard <nick@nikopia.org> | 2014-07-16 07:48:38 +0400
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-16 07:48:38 +0400
commit | c5b89c15c968ea830ab9908744d37bd54b4cee41
tree | afc06c875e1d5fe5165a53089e90be59185c0420 /SoftMaxTree.lua
parent | 4437925e8d3369ebbb8ee3dc3dcaf21aa29bdef2
SoftMaxTree.static
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 18 |
1 file changed, 10 insertions(+), 8 deletions(-)
```diff
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index f37396e..25eb4e5 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -6,7 +6,7 @@ local SoftMaxTree, parent = torch.class('nn.SoftMaxTree', 'nn.Module')
 -- Only works with a tree (one parent per child)
 ------------------------------------------------------------------------
 
-function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, verbose)
+function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, static, verbose)
    parent.__init(self)
    self.rootId = rootId or 1
    self.inputSize = inputSize
@@ -139,6 +139,7 @@ function SoftMaxTree:__init(inputSize, hierarchy, rootId, accUpdate, verbose)
    self._gradInput = torch.Tensor()
    self._gradTarget = torch.IntTensor() -- dummy
    self.gradInput = {self._gradInput, self._gradTarget}
+   self.static = (static == nil) and true or static
    self:reset()
 end
 
@@ -186,7 +187,8 @@ end
 
 -- when static is true, return parameters with static keys
 -- i.e. keys that don't change from batch to batch
-function SoftMaxTree:parameters(static)
+function SoftMaxTree:parameters()
+   static = self.static
    local params, grads = {}, {}
    local updated = false
    for parentId, scale in pairs(self.updates) do
@@ -233,9 +235,9 @@ function SoftMaxTree:parameters(static)
    return params, grads
 end
 
-function SoftMaxTree:updateParameters(learningRate, partial)
+function SoftMaxTree:updateParameters(learningRate)
    assert(not self.accUpdate)
-   local params, gradParams = self:parameters(partial)
+   local params, gradParams = self:parameters()
    if params then
       for k,param in pairs(params) do
         param:add(-learningRate, gradParams[k])
@@ -257,8 +259,8 @@ function SoftMaxTree:getNodeParameters(parentId)
    return {weight, bias}
 end
 
-function SoftMaxTree:zeroGradParameters(partial)
-   local _,gradParams = self:parameters(partial)
+function SoftMaxTree:zeroGradParameters()
+   local _,gradParams = self:parameters()
    for k,gradParam in pairs(gradParams) do
       gradParam:zero()
    end
@@ -332,8 +334,8 @@ function SoftMaxTree:sharedClone()
    return smt:share(self, 'weight', 'bias')
 end
 
-function SoftMaxTree:maxNorm(maxNorm, partial)
-   local params = self:parameters(partial)
+function SoftMaxTree:maxNorm(maxNorm)
+   local params = self:parameters()
    if params then
       for k,param in pairs(params) do
         if param:dim() == 2 and maxNorm then
```
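In short, `static` moves from being an argument of `parameters()` to a constructor flag that defaults to `true`, and the old `partial` argument disappears from `updateParameters`, `zeroGradParameters`, and `maxNorm`. A minimal usage sketch of the new API follows; the `hierarchy` table, sizes, and learning rate below are made-up toy values, not taken from this commit:

```lua
require 'nnx' -- assumes the package that provides nn.SoftMaxTree

-- Toy hierarchy: parentId -> IntTensor of childIds.
-- Node 1 (the root) has children 2 and 3; node 2 has leaves 4 and 5.
local hierarchy = {
   [1] = torch.IntTensor{2, 3},
   [2] = torch.IntTensor{4, 5}
}

-- static is now the 5th constructor argument;
-- (static == nil) and true or static  means it defaults to true when omitted.
local smt = nn.SoftMaxTree(100, hierarchy, 1, false, true)

-- parameters() no longer takes an argument; it reads self.static internally.
local params, gradParams = smt:parameters()

-- The methods that used to forward a `partial` flag now take none:
smt:zeroGradParameters()
smt:updateParameters(0.1)
smt:maxNorm(2)
```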