Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2014-06-04 01:38:56 +0400
committernicholas-leonard <nick@nikopia.org>2014-06-04 01:38:56 +0400
commit57d7316d80bc6611a7c6d4088283f3117dc0e569 (patch)
treefd7d8cb097f8c50841a10fce25c6e31c5413c505 /SoftMaxTree.lua
parentf2b1a60fad4ab4cb8c51287743837142e21aabd4 (diff)
fixed SoftMaxTree:updateOutput
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--SoftMaxTree.lua13
1 files changed, 11 insertions, 2 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index ce5e2bf..fc2b55c 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -156,8 +156,8 @@ function SoftMaxTree:updateOutput(inputTable)
self._multiBuffer:resize(input:size(1)*self.maxFamilyPath)
self.batchSize = input:size(1)
if self._nodeUpdateHost then
- self._nodeUpdateHost:resize(input:size(1)*self.maxDept)
- self._nodeUpdateCuda:resize(input:size(1)*self.maxDept)
+ self._nodeUpdateHost:resize(input:size(1),self.maxDept)
+ self._nodeUpdateCuda:resize(input:size(1),self.maxDept)
end
end
return input.nn.SoftMaxTree_updateOutput(self, input, target)
@@ -205,6 +205,15 @@ function SoftMaxTree:parameters(static)
return params, grads
end
+function SoftMaxTree:updateParameters(learningRate, partial)
+ local params, gradParams = self:parameters(partial)
+ if params then
+ for k,param in pairs(params) do
+ param:add(-learningRate, gradParams[k])
+ end
+ end
+end
+
function SoftMaxTree:getNodeParameters(parentId)
local node = self.parentChildren:select(1,parentId)
local start = node[1]