diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-06-04 01:38:56 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-06-04 01:38:56 +0400 |
commit | 57d7316d80bc6611a7c6d4088283f3117dc0e569 (patch) | |
tree | fd7d8cb097f8c50841a10fce25c6e31c5413c505 /SoftMaxTree.lua | |
parent | f2b1a60fad4ab4cb8c51287743837142e21aabd4 (diff) |
fixed SoftMaxTree:updateOutput
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index ce5e2bf..fc2b55c 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -156,8 +156,8 @@ function SoftMaxTree:updateOutput(inputTable) self._multiBuffer:resize(input:size(1)*self.maxFamilyPath) self.batchSize = input:size(1) if self._nodeUpdateHost then - self._nodeUpdateHost:resize(input:size(1)*self.maxDept) - self._nodeUpdateCuda:resize(input:size(1)*self.maxDept) + self._nodeUpdateHost:resize(input:size(1),self.maxDept) + self._nodeUpdateCuda:resize(input:size(1),self.maxDept) end end return input.nn.SoftMaxTree_updateOutput(self, input, target) @@ -205,6 +205,15 @@ function SoftMaxTree:parameters(static) return params, grads end +function SoftMaxTree:updateParameters(learningRate, partial) + local params, gradParams = self:parameters(partial) + if params then + for k,param in pairs(params) do + param:add(-learningRate, gradParams[k]) + end + end +end + function SoftMaxTree:getNodeParameters(parentId) local node = self.parentChildren:select(1,parentId) local start = node[1] |