
github.com/clementfarabet/lua---nnx.git
author     nicholas-leonard <nick@nikopia.org>   2014-07-08 21:50:41 +0400
committer  nicholas-leonard <nick@nikopia.org>   2014-07-08 21:50:41 +0400
commit     11e4e58f22700b8fb4ed033243cb74b0bc14dd53 (patch)
tree       3df4ec1f5065aa2badc0c4a212608acc96ba110f /SoftMaxTree.lua
parent     f6fcccff55b16e1a1ac1d414f0a04f6efcb910d8 (diff)
fixed accUpdate bug
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--  SoftMaxTree.lua | 15 +++++++++------
1 file changed, 9 insertions(+), 6 deletions(-)
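
For context: in the Torch7 nn convention that nnx follows, a module constructed with accUpdate enabled applies updates directly to its weights inside accUpdateGradParameters, so it never allocates gradWeight or gradBias. All three hunks below guard code that previously assumed those tensors exist. A minimal sketch of the convention (the toy module is hypothetical; only the field names are taken from the diff):

require 'nn'

-- Hypothetical toy module illustrating the accUpdate convention this
-- commit relies on: gradient buffers exist only when accUpdate is off.
local ToyLinear, parent = torch.class('nn.ToyLinear', 'nn.Module')

function ToyLinear:__init(inputSize, outputSize, accUpdate)
   parent.__init(self)
   self.accUpdate = accUpdate
   self.weight = torch.Tensor(outputSize, inputSize):uniform(-0.1, 0.1)
   self.bias = torch.Tensor(outputSize):zero()
   if not self.accUpdate then
      -- only the default mode accumulates into separate grad buffers
      self.gradWeight = torch.Tensor(outputSize, inputSize):zero()
      self.gradBias = torch.Tensor(outputSize):zero()
   end
end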
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index 31a9a2e..5db6a10 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -246,8 +246,10 @@ function SoftMaxTree:type(type)
    if type and (type == 'torch.FloatTensor' or type == 'torch.DoubleTensor' or type == 'torch.CudaTensor') then
       self.weight = self.weight:type(type)
       self.bias = self.bias:type(type)
-      self.gradWeight = self.gradWeight:type(type)
-      self.gradBias = self.gradBias:type(type)
+      if not self.accUpdate then
+         self.gradWeight = self.gradWeight:type(type)
+         self.gradBias = self.gradBias:type(type)
+      end
       self._nodeBuffer = self._nodeBuffer:type(type)
       self._multiBuffer = self._multiBuffer:type(type)
       self.output = self.output:type(type)
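
A hedged sketch of what this first hunk fixes, assuming smt is any SoftMaxTree constructed with accUpdate enabled (construction is not shown in this diff):

-- smt.gradWeight is nil because it was never allocated under accUpdate,
-- so the old code crashed on self.gradWeight:type(type); with the guard,
-- only the tensors that actually exist are converted.
smt:type('torch.FloatTensor')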
@@ -295,8 +297,10 @@ function SoftMaxTree:sharedClone()
    smt.childParent = self.childParent
    smt.maxFamilyPath = self.maxFamilyPath
    smt.maxDept = self.maxDept
-   smt.gradWeight = self.gradWeight:clone()
-   smt.gradBias = self.gradBias:clone()
+   if not self.accUpdate then
+      smt.gradWeight = self.gradWeight:clone()
+      smt.gradBias = self.gradBias:clone()
+   end
    if type == 'torch.CudaTensor' then
       smt.parentChildrenCuda = self.parentChildrenCuda
       smt.childParentCuda = self.childParentCuda
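
The same guard applied to sharedClone(), in a hedged sketch with the same assumed smt: the clone receives its own gradient buffers only when they exist.

local clone = smt:sharedClone()
-- weight and bias are presumably shared with the original (hence the
-- method name); gradWeight/gradBias are fresh clones only when accUpdate
-- is off, and simply absent otherwise.
print(clone.gradWeight)   -- nil under accUpdate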
@@ -305,10 +309,9 @@
 end
 
 function SoftMaxTree:maxNorm(maxNorm, partial)
-   local params, gradParams = self:parameters(partial)
+   local params = self:parameters(partial)
    if params then
       for k,param in pairs(params) do
-         param:add(-learningRate, gradParams[k])
          if param:dim() == 2 and maxNorm then
             param:renorm(2,1,maxNorm)
          end
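
The last hunk also removes a stray gradient-descent step from maxNorm: the deleted line referenced an undefined learningRate (and, under accUpdate, a nil gradParams), conflating renorming with updating. A hedged usage sketch after the fix, with the same assumed smt:

-- maxNorm now only constrains parameters: each row of a 2D parameter is
-- renormed to an L2 norm of at most 2 here; no update step is attempted.
smt:maxNorm(2)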