
github.com/torch/nn.git
author    Jonas Gehring <jgehring@fb.com>  2017-05-04 17:54:36 +0300
committer Jonas Gehring <jgehring@fb.com>  2017-05-04 17:54:42 +0300
commit    ad8389a758714af7e866e1db0ab830c8a9da31ee (patch)
tree      1c9e4b8fb6a918bc4f6c034e73263372f59ccad9 /WeightNorm.lua
parent    72c47ce4610bec72e9b8c5ede715477f61d893a1 (diff)
Fix WeightNorm serialization for permuted weight matrices
For layers where the first weight dimension does not correspond to the output dimension (outputDim > 1), self.viewOut will *not* match self.weight:size(), because its entries are permuted to bring the output dimension to the front. This is fixed by introducing another member variable, self.weightSize, that simply holds the original size of the weight tensor.
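To make the mismatch concrete, below is a minimal sketch in plain Lua/Torch of the size bookkeeping done in WeightNorm:__init, matching the first hunk of the diff below; the 4x5 weight and outputDim = 2 are hypothetical values chosen for illustration.

    require 'torch'

    -- Hypothetical layer whose *second* weight dimension is the output
    -- dimension (outputDim = 2), i.e. a permuted weight layout.
    local weight = torch.Tensor(4, 5)
    local outputDim = 2

    -- What the constructor does: copy the size, then bubble the output
    -- dimension up to the front of viewOut. Each call to size() returns
    -- a fresh LongStorage, so weightSize is not aliased to viewOut.
    local viewOut = weight:size()
    local weightSize = weight:size()   -- the fix: remember the original size
    for i = outputDim - 1, 1, -1 do
        viewOut[i], viewOut[i + 1] = viewOut[i + 1], viewOut[i]
    end

    print(weightSize)  -- contains 4, 5: correct size for re-allocation
    print(viewOut)     -- contains 5, 4: permuted, wrong size for the weight

Before this patch, read() re-allocated the weight from self.viewOut, which is only correct when outputDim == 1.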
Diffstat (limited to 'WeightNorm.lua')
-rw-r--r--  WeightNorm.lua | 18
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/WeightNorm.lua b/WeightNorm.lua
index 3bbbea8..14282fd 100644
--- a/WeightNorm.lua
+++ b/WeightNorm.lua
@@ -34,6 +34,7 @@ function WeightNorm:__init(module, outputDim)
 
     -- view size back to original weight
     self.viewOut = self.weight:size()
+    self.weightSize = self.weight:size()
 
     -- bubble outputDim size up to the front
     for i = self.outputDim - 1, 1, -1 do
@@ -186,6 +187,9 @@ function WeightNorm:write(file)
     self.gradWeight = nil
     self.modules[1].weight = nil
     self.modules[1].gradWeight = nil
+    if not self.weightSize then
+        self.weightSize = weight:size()
+    end
 
     parent.write(self, file)
 
@@ -199,10 +203,12 @@ function WeightNorm:read(file)
     parent.read(self, file)
 
     -- Re-compute weight and gradWeight
-    self.modules[1].weight = self.v.new(self.viewOut)
-    self.modules[1].gradWeight = self.v.new(self.viewOut)
-    self.weight = self.modules[1].weight
-    self.gradWeight = self.modules[1].gradWeight
-    self:updateWeight()
-    self.gradWeight:copy(self:permuteOut(self.gradV))
+    if not self.weight then
+        self.modules[1].weight = self.v.new(self.weightSize)
+        self.modules[1].gradWeight = self.v.new(self.weightSize)
+        self.weight = self.modules[1].weight
+        self.gradWeight = self.modules[1].gradWeight
+        self:updateWeight()
+        self.gradWeight:copy(self:permuteOut(self.gradV))
+    end
 end
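A hedged round-trip sketch of the failure this patch repairs: nn.SpatialFullConvolution stores its weight as nInputPlane x nOutputPlane x kH x kW, so its output dimension is the second one, exactly the permuted case described in the commit message. The module parameters and file name below are illustrative, not taken from the patch.

    require 'nn'

    -- A module whose output dimension is *not* the first weight dimension:
    -- weight is 3x8x5x5, so we wrap it with outputDim = 2.
    local wn = nn.WeightNorm(nn.SpatialFullConvolution(3, 8, 5, 5), 2)

    -- write() drops weight/gradWeight (they are recomputed from v and g)
    -- and, with this patch, records self.weightSize for older instances too.
    torch.save('wn.t7', wn)
    local loaded = torch.load('wn.t7')

    -- Before the fix, read() allocated the weight as self.v.new(self.viewOut),
    -- i.e. with the permuted size 8x3x5x5; using self.weightSize instead,
    -- the module round-trips with its original 3x8x5x5 weight.
    assert(loaded.weight:isSameSizeAs(wn.weight))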