Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2015-04-20 18:47:58 +0300
committernicholas-leonard <nick@nikopia.org>2015-04-20 18:47:58 +0300
commita9bab84732c4cf1fe5c50ae0272f81b7275f3306 (patch)
treee27416ffff49409353f8fa08c085fbb01fb235d8 /SoftMaxTree.lua
parent35cfc17ee40ceae4916b1f1754fbdb5eea10c5aa (diff)
parent3642beecfba9bee13e4ee1e109f7ef6da41a0452 (diff)
Merge branch 'master' of github.com:clementfarabet/lua---nnx
Diffstat (limited to 'SoftMaxTree.lua')
-rw-r--r--SoftMaxTree.lua11
1 files changed, 8 insertions, 3 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index af4be64..6633359 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -264,7 +264,11 @@ function SoftMaxTree:zeroGradParameters()
for k,gradParam in pairs(gradParams) do
gradParam:zero()
end
- self.updates = {}
+ -- loop is used instead of 'self.updates = {}'
+ -- to handle the case when updates are shared
+ for k,v in pairs(self.updates) do
+ self.updates[k] = nil
+ end
end
function SoftMaxTree:type(type)
@@ -325,9 +329,10 @@ function SoftMaxTree:sharedClone()
smt.childParent = self.childParent
smt.maxFamilyPath = self.maxFamilyPath
smt.maxDept = self.maxDept
+ smt.updates = self.updates
if not self.accUpdate then
- smt.gradWeight = self.gradWeight:clone()
- smt.gradBias = self.gradBias:clone()
+ smt.gradWeight = self.gradWeight
+ smt.gradBias = self.gradBias
end
if type == 'torch.CudaTensor' then
smt.parentChildrenCuda = self.parentChildrenCuda