diff options
author | Guillaume Klein <guillaume.klein@systrangroup.com> | 2017-01-03 11:43:32 +0300 |
---|---|---|
committer | Guillaume Klein <guillaume.klein@systrangroup.com> | 2017-01-03 11:43:32 +0300 |
commit | e37c33d04eef3bcd7588eb85f3be580116b82f86 (patch) | |
tree | d906d4aae53628f9835dc050d392aafd7ff1f817 | |
parent | 422374f615e596e4d4418a7d07e49bde49668a27 (diff) |
Use conditional branching to call the function shared version (#1091)
This avoids overwriting the `accUpdateGradParameters` field which
exposes the function to serialization and makes the saved object
incompatible with other Lua versions.
-rw-r--r-- | Module.lua | 22 |
1 files changed, 13 insertions, 9 deletions
@@ -47,13 +47,17 @@ function Module:accGradParameters(input, gradOutput, scale) end function Module:accUpdateGradParameters(input, gradOutput, lr) - local gradWeight = self.gradWeight - local gradBias = self.gradBias - self.gradWeight = self.weight - self.gradBias = self.bias - self:accGradParameters(input, gradOutput, -lr) - self.gradWeight = gradWeight - self.gradBias = gradBias + if self.shared then + self:sharedAccUpdateGradParameters(input, gradOutput, lr) + else + local gradWeight = self.gradWeight + local gradBias = self.gradBias + self.gradWeight = self.weight + self.gradBias = self.bias + self:accGradParameters(input, gradOutput, -lr) + self.gradWeight = gradWeight + self.gradBias = gradBias + end end function Module:sharedAccUpdateGradParameters(input, gradOutput, lr) @@ -95,8 +99,8 @@ function Module:share(mlp, ...) for i,v in ipairs(arg) do if self[v] ~= nil then self[v]:set(mlp[v]) - self.accUpdateGradParameters = self.sharedAccUpdateGradParameters - mlp.accUpdateGradParameters = mlp.sharedAccUpdateGradParameters + self.shared = true + mlp.shared = true end end return self |