Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGuillaume Klein <guillaume.klein@systrangroup.com>2017-01-03 11:43:32 +0300
committerGuillaume Klein <guillaume.klein@systrangroup.com>2017-01-03 11:43:32 +0300
commite37c33d04eef3bcd7588eb85f3be580116b82f86 (patch)
treed906d4aae53628f9835dc050d392aafd7ff1f817
parent422374f615e596e4d4418a7d07e49bde49668a27 (diff)
Use conditional branching to call the function shared version (#1091)
This avoids overwriting the `accUpdateGradParameters` field which exposes the function to serialization and makes the saved object incompatible with other Lua versions.
-rw-r--r--Module.lua22
1 files changed, 13 insertions, 9 deletions
diff --git a/Module.lua b/Module.lua
index c1a0328..f7fd08f 100644
--- a/Module.lua
+++ b/Module.lua
@@ -47,13 +47,17 @@ function Module:accGradParameters(input, gradOutput, scale)
end
function Module:accUpdateGradParameters(input, gradOutput, lr)
- local gradWeight = self.gradWeight
- local gradBias = self.gradBias
- self.gradWeight = self.weight
- self.gradBias = self.bias
- self:accGradParameters(input, gradOutput, -lr)
- self.gradWeight = gradWeight
- self.gradBias = gradBias
+ if self.shared then
+ self:sharedAccUpdateGradParameters(input, gradOutput, lr)
+ else
+ local gradWeight = self.gradWeight
+ local gradBias = self.gradBias
+ self.gradWeight = self.weight
+ self.gradBias = self.bias
+ self:accGradParameters(input, gradOutput, -lr)
+ self.gradWeight = gradWeight
+ self.gradBias = gradBias
+ end
end
function Module:sharedAccUpdateGradParameters(input, gradOutput, lr)
@@ -95,8 +99,8 @@ function Module:share(mlp, ...)
for i,v in ipairs(arg) do
if self[v] ~= nil then
self[v]:set(mlp[v])
- self.accUpdateGradParameters = self.sharedAccUpdateGradParameters
- mlp.accUpdateGradParameters = mlp.sharedAccUpdateGradParameters
+ self.shared = true
+ mlp.shared = true
end
end
return self