Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/clementfarabet/lua---nnx.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Léonard <nick@nikopia.org>2015-04-09 17:26:56 +0300
committerNicholas Léonard <nick@nikopia.org>2015-04-09 17:26:56 +0300
commit3642beecfba9bee13e4ee1e109f7ef6da41a0452 (patch)
tree6ec3d812bd47efe335646af9415082c6e4e9bf81
parentb10c3aa2b51b3c6c848c4d8d2c2517ac59872269 (diff)
parent15a952f45418d4c581351473981ec7079481cdb5 (diff)
Merge pull request #36 from akhti/master
Fix SoftMaxTree:sharedClone
-rw-r--r--SoftMaxTree.lua11
1 files changed, 8 insertions, 3 deletions
diff --git a/SoftMaxTree.lua b/SoftMaxTree.lua
index e3ff61e..b7bfcbc 100644
--- a/SoftMaxTree.lua
+++ b/SoftMaxTree.lua
@@ -264,7 +264,11 @@ function SoftMaxTree:zeroGradParameters()
for k,gradParam in pairs(gradParams) do
gradParam:zero()
end
- self.updates = {}
+ -- loop is used instead of 'self.updates = {}'
+ -- to handle the case when updates are shared
+ for k,v in pairs(self.updates) do
+ self.updates[k] = nil
+ end
end
function SoftMaxTree:type(type)
@@ -325,9 +329,10 @@ function SoftMaxTree:sharedClone()
smt.childParent = self.childParent
smt.maxFamilyPath = self.maxFamilyPath
smt.maxDept = self.maxDept
+ smt.updates = self.updates
if not self.accUpdate then
- smt.gradWeight = self.gradWeight:clone()
- smt.gradBias = self.gradBias:clone()
+ smt.gradWeight = self.gradWeight
+ smt.gradBias = self.gradBias
end
if type == 'torch.CudaTensor' then
smt.parentChildrenCuda = self.parentChildrenCuda