diff options
author | nicholas-leonard <nick@nikopia.org> | 2014-07-08 21:50:41 +0400 |
---|---|---|
committer | nicholas-leonard <nick@nikopia.org> | 2014-07-08 21:50:41 +0400 |
commit | 11e4e58f22700b8fb4ed033243cb74b0bc14dd53 (patch) | |
tree | 3df4ec1f5065aa2badc0c4a212608acc96ba110f /SoftMaxTree.lua | |
parent | f6fcccff55b16e1a1ac1d414f0a04f6efcb910d8 (diff) |
fixed accUpdate bug
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r-- | SoftMaxTree.lua | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua index 31a9a2e..5db6a10 100644 --- a/SoftMaxTree.lua +++ b/SoftMaxTree.lua @@ -246,8 +246,10 @@ function SoftMaxTree:type(type) if type and (type == 'torch.FloatTensor' or type == 'torch.DoubleTensor' or type == 'torch.CudaTensor') then self.weight = self.weight:type(type) self.bias = self.bias:type(type) - self.gradWeight = self.gradWeight:type(type) - self.gradBias = self.gradBias:type(type) + if not self.accUpdate then + self.gradWeight = self.gradWeight:type(type) + self.gradBias = self.gradBias:type(type) + end self._nodeBuffer = self._nodeBuffer:type(type) self._multiBuffer = self._multiBuffer:type(type) self.output = self.output:type(type) @@ -295,8 +297,10 @@ function SoftMaxTree:sharedClone() smt.childParent = self.childParent smt.maxFamilyPath = self.maxFamilyPath smt.maxDept = self.maxDept - smt.gradWeight = self.gradWeight:clone() - smt.gradBias = self.gradBias:clone() + if not self.accUpdate then + smt.gradWeight = self.gradWeight:clone() + smt.gradBias = self.gradBias:clone() + end if type == 'torch.CudaTensor' then smt.parentChildrenCuda = self.parentChildrenCuda smt.childParentCuda = self.childParentCuda @@ -305,10 +309,9 @@ function SoftMaxTree:sharedClone() end function SoftMaxTree:maxNorm(maxNorm, partial) - local params, gradParams = self:parameters(partial) + local params = self:parameters(partial) if params then for k,param in pairs(params) do - param:add(-learningRate, gradParams[k]) if param:dim() == 2 and maxNorm then param:renorm(2,1,maxNorm) end |