diff options
author | Jonas Gehring <jgehring@fb.com> | 2017-05-04 17:54:36 +0300 |
---|---|---|
committer | Jonas Gehring <jgehring@fb.com> | 2017-05-04 17:54:42 +0300 |
commit | ad8389a758714af7e866e1db0ab830c8a9da31ee (patch) | |
tree | 1c9e4b8fb6a918bc4f6c034e73263372f59ccad9 /WeightNorm.lua | |
parent | 72c47ce4610bec72e9b8c5ede715477f61d893a1 (diff) |
Fix WeightNorm serialization for permuted weight matrices
For layers where the first weight dimension does not correspond to the output
dimension, self.viewOut will *not* correspond to self.weight:size(). This is
fixed by introducing another member variable that simply holds the original size
of the weight matrix.
Diffstat (limited to 'WeightNorm.lua')
-rw-r--r-- | WeightNorm.lua | 18 |
1 file changed, 12 insertions, 6 deletions
diff --git a/WeightNorm.lua b/WeightNorm.lua index 3bbbea8..14282fd 100644 --- a/WeightNorm.lua +++ b/WeightNorm.lua @@ -34,6 +34,7 @@ function WeightNorm:__init(module, outputDim) -- view size back to original weight self.viewOut = self.weight:size() + self.weightSize = self.weight:size() -- bubble outputDim size up to the front for i = self.outputDim - 1, 1, -1 do @@ -186,6 +187,9 @@ function WeightNorm:write(file) self.gradWeight = nil self.modules[1].weight = nil self.modules[1].gradWeight = nil + if not self.weightSize then + self.weightSize = weight:size() + end parent.write(self, file) @@ -199,10 +203,12 @@ function WeightNorm:read(file) parent.read(self, file) -- Re-compute weight and gradWeight - self.modules[1].weight = self.v.new(self.viewOut) - self.modules[1].gradWeight = self.v.new(self.viewOut) - self.weight = self.modules[1].weight - self.gradWeight = self.modules[1].gradWeight - self:updateWeight() - self.gradWeight:copy(self:permuteOut(self.gradV)) + if not self.weight then + self.modules[1].weight = self.v.new(self.weightSize) + self.modules[1].gradWeight = self.v.new(self.weightSize) + self.weight = self.modules[1].weight + self.gradWeight = self.modules[1].gradWeight + self:updateWeight() + self.gradWeight:copy(self:permuteOut(self.gradV)) + end end |