github.com/torch/nn.git

author     Richard Assar <richard.assar@gmail.com>    2017-03-07 02:58:09 +0300
committer  Nicholas Leonard <nleonard@twitter.com>    2017-05-11 21:50:25 +0300
commit     65f42723f51d13683a431329452e5f385988e984 (patch)
tree       44dcda0026791ffdb085d6602a5a0dd91d5020b1
parent     5f9f4f87c896cd1a8bf0999c49c892e368368ad7 (diff)
Adding LinearWeightNorm layer, test and updated documentation
-rw-r--r--                Linear.lua              6
-rwxr-xr-x                LinearWeightNorm.lua  168
-rwxr-xr-x [-rw-r--r--]   doc/containers.md       7
-rwxr-xr-x [-rw-r--r--]   doc/simple.md          14
-rwxr-xr-x [-rw-r--r--]   init.lua                1
-rwxr-xr-x                test.lua               48
6 files changed, 239 insertions, 5 deletions
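For orientation, here is a minimal usage sketch of the layer this patch adds. It assumes the patched nn is installed and follows the constructor defined in LinearWeightNorm.lua below.

```lua
require 'nn'

-- Weight-normalized linear layer: 5 inputs -> 20 outputs.
-- Internally it stores a direction matrix v and a per-row gain g, and
-- recomputes weight = g * v / ||v|| lazily while training.
local layer = nn.LinearWeightNorm(5, 20)

local input  = torch.rand(10, 5)   -- batch of 10 samples
local output = layer:forward(input)
print(output:size())               -- 10x20
```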
diff --git a/Linear.lua b/Linear.lua
index 3221227..09b5979 100644
--- a/Linear.lua
+++ b/Linear.lua
@@ -42,7 +42,7 @@ function Linear:reset(stdv)
return self
end
-local function updateAddBuffer(self, input)
+function Linear:updateAddBuffer(input)
local nframe = input:size(1)
self.addBuffer = self.addBuffer or input.new()
if self.addBuffer:nElement() ~= nframe then
@@ -62,7 +62,7 @@ function Linear:updateOutput(input)
if self.output:nElement() ~= nElement then
self.output:zero()
end
- updateAddBuffer(self, input)
+ self:updateAddBuffer(input)
self.output:addmm(0, self.output, 1, input, self.weight:t())
if self.bias then self.output:addr(1, self.addBuffer, self.bias) end
else
@@ -99,7 +99,7 @@ function Linear:accGradParameters(input, gradOutput, scale)
self.gradWeight:addmm(scale, gradOutput:t(), input)
if self.bias then
-- update the size of addBuffer if the input is not the same size as the one we had in last updateGradInput
- updateAddBuffer(self, input)
+ self:updateAddBuffer(input)
self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
end
end
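The hunks above promote `updateAddBuffer` from a file-local helper to a method on `nn.Linear`, so subclasses such as the new `nn.LinearWeightNorm` can reuse it via `self:updateAddBuffer(input)`. A rough check of the promoted method, as a sketch assuming the patched nn is installed:

```lua
local lin   = nn.Linear(5, 20)
local input = torch.rand(10, 5)

-- Previously only reachable as a local function inside Linear.lua;
-- now callable on any nn.Linear (or subclass) instance.
lin:updateAddBuffer(input)
print(lin.addBuffer:size(1))   -- 10: one entry of 1 per batch sample
```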
diff --git a/LinearWeightNorm.lua b/LinearWeightNorm.lua
new file mode 100755
index 0000000..a712f55
--- /dev/null
+++ b/LinearWeightNorm.lua
@@ -0,0 +1,168 @@
+local LinearWeightNorm, parent = torch.class('nn.LinearWeightNorm', 'nn.Linear')
+
+function LinearWeightNorm:__init(inputSize, outputSize, bias, eps)
+ nn.Module.__init(self) -- Skip nn.Linear constructor
+
+ local bias = ((bias == nil) and true) or bias
+
+ self.eps = eps or 1e-16
+
+ self.outputSize = outputSize
+ self.inputSize = inputSize
+
+ self.v = torch.Tensor(outputSize, inputSize)
+ self.gradV = torch.Tensor(outputSize, inputSize)
+
+ self.weight = torch.Tensor(outputSize, inputSize)
+
+ self.g = torch.Tensor(outputSize,1)
+ self.gradG = torch.Tensor(outputSize,1)
+
+ self.norm = torch.Tensor(outputSize,1)
+ self.scale = torch.Tensor(outputSize,1)
+
+ if bias then
+ self.bias = torch.Tensor(outputSize)
+ self.gradBias = torch.Tensor(outputSize)
+ end
+
+ self:reset()
+end
+
+function LinearWeightNorm:evaluate()
+ if self.train ~= false then
+ self:updateWeightMatrix()
+ end
+
+ parent.evaluate(self)
+end
+
+function LinearWeightNorm:initFromWeight(weight)
+ weight = weight or self.weight
+
+ self.g:norm(weight,2,2):clamp(self.eps,math.huge)
+ self.v:copy(weight)
+
+ return self
+end
+
+function LinearWeightNorm.fromLinear(linear)
+ local module = nn.LinearWeightNorm(linear.weight:size(2), linear.weight:size(1), torch.isTensor(linear.bias))
+ module.weight:copy(linear.weight)
+ module:initFromWeight()
+
+ if linear.bias then
+ module.bias:copy(linear.bias)
+ end
+
+ return module
+end
+
+function LinearWeightNorm:toLinear()
+ self:updateWeightMatrix()
+
+ local module = nn.Linear(self.inputSize, self.outputSize, torch.isTensor(self.bias))
+
+ module.weight:copy(self.weight)
+ if self.bias then
+ module.bias:copy(self.bias)
+ end
+
+ return module
+end
+
+function LinearWeightNorm:parameters()
+ if self.bias then
+ return {self.v, self.g, self.bias}, {self.gradV, self.gradG, self.gradBias}
+ else
+ return {self.v, self.g}, {self.gradV, self.gradG}
+ end
+end
+
+function LinearWeightNorm:reset(stdv)
+ if stdv then
+ stdv = stdv * math.sqrt(3)
+ else
+ stdv = 1 / math.sqrt(self.inputSize)
+ end
+
+ self.weight:uniform(-stdv,stdv)
+ self:initFromWeight()
+
+ if self.bias then
+ self.bias:uniform(-stdv,stdv)
+ end
+end
+
+function LinearWeightNorm:updateWeightMatrix()
+ if self.norm:dim() == 0 then self.norm:resizeAs(self.g) end
+ if self.scale:dim() == 0 then self.scale:resizeAs(self.g) end
+ if self.weight:dim() == 0 then self.weight:resizeAs(self.v) end
+
+ self.norm:norm(self.v,2,2):clamp(self.eps,math.huge)
+ self.scale:cdiv(self.g,self.norm)
+ self.weight:cmul(self.v,self.scale:expandAs(self.v))
+end
+
+function LinearWeightNorm:updateOutput(input)
+ if self.train ~= false then
+ self:updateWeightMatrix()
+ end
+
+ return parent.updateOutput(self, input)
+end
+
+function LinearWeightNorm:accGradParameters(input, gradOutput, scale)
+ scale = scale or 1
+ if input:dim() == 1 then
+ self.gradV:addr(scale, gradOutput, input)
+ if self.bias then self.gradBias:add(scale, gradOutput) end
+ elseif input:dim() == 2 then
+ self.gradV:addmm(scale, gradOutput:t(), input)
+ if self.bias then
+ -- update the size of addBuffer if the input is not the same size as the one we had in last updateGradInput
+ self:updateAddBuffer(input)
+ self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
+ end
+ end
+
+ local scale = self.scale:expandAs(self.v)
+ local norm = self.norm:expandAs(self.v)
+
+ self.weight:cmul(self.gradV,self.v):cdiv(norm)
+ self.gradG:sum(self.weight,2)
+
+ self.gradV:cmul(scale)
+
+ self.weight:cmul(self.v,scale):cdiv(norm)
+ self.weight:cmul(self.gradG:expandAs(self.weight))
+
+ self.gradV:add(-1,self.weight)
+end
+
+function LinearWeightNorm:defaultAccUpdateGradParameters(input, gradOutput, lr)
+ local gradV = self.gradV
+ local gradG = self.gradG
+ local gradBias = self.gradBias
+
+ self.gradV = self.v
+ self.gradG = self.g
+ self.gradBias = self.bias
+
+ self:accGradParameters(input, gradOutput, -lr)
+
+ self.gradV = gradV
+ self.gradG = gradG
+ self.gradBias = gradBias
+end
+
+function LinearWeightNorm:clearState()
+ nn.utils.clear(self, 'weight', 'norm', 'scale')
+ return parent.clearState(self)
+end
+
+function LinearWeightNorm:__tostring__()
+ return torch.type(self) ..
+ string.format('(%d -> %d)', self.inputSize, self.outputSize) ..
+ (self.bias == nil and ' without bias' or '')
+end
\ No newline at end of file
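A minimal sketch of the reparametrization that `updateWeightMatrix` maintains, assuming the patched nn is installed: each row of `weight` is `g * v / ||v||`, with the row norms clamped below by `eps`.

```lua
local layer = nn.LinearWeightNorm(5, 20)
layer:updateWeightMatrix()

-- Recompute the expected weight from the stored parameters.
local norm     = layer.v:norm(2, 2):clamp(layer.eps, math.huge)  -- 20x1 row norms of v
local expected = torch.cmul(layer.v, torch.cdiv(layer.g, norm):expandAs(layer.v))

print((layer.weight - expected):abs():max())   -- ~0
```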
diff --git a/doc/containers.md b/doc/containers.md
index 5e404e3..cecf782 100644..100755
--- a/doc/containers.md
+++ b/doc/containers.md
@@ -346,9 +346,12 @@ mlp = nn.Bottle(nn.Linear(10, 2))
module = nn.WeightNorm(module)
```
-WeightNorm implements the reparametrization presented in [Weight Normalization](https://arxiv.org/pdf/1602.07868v3.pdf), which decouples the length of neural network weight vectors from their direction. The weight vectors `w` is determined instead by parameters `g` and `v` such that `w = g * v / ||v||`, where `||v||` is the euclidean norm of vector v. This container can wrap nn layers with weights.
+WeightNorm implements the reparametrization presented in [Weight Normalization](https://arxiv.org/pdf/1602.07868v3.pdf), which decouples the length of neural network weight vectors from their direction. The weight vector `w` is determined instead by parameters `g` and `v` such that `w = g * v / ||v||`, where `||v||` is the Euclidean norm of vector `v`. This container can wrap nn layers with weights.
+
+It accepts a parameter `outputDim` that represents the output dimension of the module weight it wraps, which defaults to 1. If `outputDim` is not 1, the container will transpose the weight appropriately. If the module weight is not 2D, e.g. in the case of convolutional layers, the container will view the weight as an appropriate 2D shape based on the `outputDim` specified by the user.
+
+An optimised version of `nn.WeightNorm(nn.Linear(inputDimension, outputDimension))` is available as `nn.LinearWeightNorm(inputDimension, outputDimension, [bias = true])`. This layer occupies less memory and is faster because it performs fewer tensor copy operations; it also maintains a dirty flag to avoid recomputing the weight matrix unnecessarily.
-It accepts a parameter ``outputDim`` that represents the output dimension of the module weight it wraps, which defaults to 1. If the outputDim is not 1, the container will transpose the weight appropriately. If the module weight is not 2D, the container will view the weight into an appropriate 2D shape based on the outputDim specified by the user.
<a name='nn.DontCast'></a>
## DontCast ##
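A short sketch contrasting the container route with the new optimised layer (a sketch assuming the patched nn; the `nn.WeightNorm(module)` call with the default `outputDim` follows the documentation above):

```lua
-- Generic container: wraps any layer with weights, 2D or not.
local generic = nn.WeightNorm(nn.Linear(10, 20))
local conv    = nn.WeightNorm(nn.SpatialConvolution(3, 16, 5, 5))  -- 4D weight viewed as 2D

-- Optimised linear-only variant added by this patch.
local fast = nn.LinearWeightNorm(10, 20)

local x = torch.rand(4, 10)
print(generic:forward(x):size())  -- 4x20
print(fast:forward(x):size())     -- 4x20
```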
diff --git a/doc/simple.md b/doc/simple.md
index 9306edf..2be635c 100644..100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -4,6 +4,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* Parameterized Modules :
* [Linear](#nn.Linear) : a linear transformation ;
+ * [LinearWeightNorm](#nn.LinearWeightNorm) : a weight normalized linear transformation ;
* [SparseLinear](#nn.SparseLinear) : a linear transformation with sparse inputs ;
* [IndexLinear](#nn.IndexLinear) : an alternative linear transformation with for sparse inputs and max normalization ;
* [Bilinear](#nn.Bilinear) : a bilinear transformation with sparse inputs ;
@@ -100,6 +101,19 @@ x = torch.Tensor(10) -- 10 inputs
y = module:forward(x)
```
+<a name="nn.LinearWeightNorm"></a>
+## LinearWeightNorm ##
+
+```lua
+module = nn.LinearWeightNorm(inputDimension, outputDimension, [bias = true])
+```
+
+LinearWeightNorm implements the reparametrization presented in [Weight Normalization](https://arxiv.org/pdf/1602.07868v3.pdf), which decouples the length of neural network weight vectors from their direction. The weight vector `w` is determined instead by parameters `g` and `v` such that `w = g * v / ||v||`, where `||v||` is the Euclidean norm of vector `v`. In all other respects this layer behaves like `nn.Linear`.
+
+To convert between `nn.Linear` and `nn.LinearWeightNorm`, you can use the `nn.LinearWeightNorm.fromLinear(linearModule)` and `weightNormModule:toLinear()` functions.
+
+Other layer types can make use of weight normalization through the [nn.WeightNorm](https://github.com/torch/nn/blob/master/doc/containers.md#nn.WeightNorm) container.
+
<a name="nn.SparseLinear"></a>
## SparseLinear ##
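A round-trip sketch for the conversion helpers described above (assuming the patched nn is installed); the outputs of all three modules should agree up to numerical precision:

```lua
local linear = nn.Linear(5, 20)
local wn     = nn.LinearWeightNorm.fromLinear(linear)  -- copies weight/bias, initializes g and v
local back   = wn:toLinear()                           -- folds g and v back into a dense weight

local x = torch.rand(10, 5)
print((linear:forward(x) - wn:forward(x)):abs():max())   -- ~0
print((wn:forward(x) - back:forward(x)):abs():max())     -- ~0
```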
diff --git a/init.lua b/init.lua
index e134b27..fb6926f 100644..100755
--- a/init.lua
+++ b/init.lua
@@ -24,6 +24,7 @@ require('nn.NaN')
require('nn.Profile')
require('nn.Linear')
+require('nn.LinearWeightNorm')
require('nn.Bilinear')
require('nn.PartialLinear')
require('nn.SparseLinear')
diff --git a/test.lua b/test.lua
index 125d1d1..9eaf21e 100755
--- a/test.lua
+++ b/test.lua
@@ -177,6 +177,54 @@ function nntest.WeightNorm()
mytester:assertTensorEq(out, outr)
end
+function nntest.LinearWeightNorm()
+ local input = torch.rand(10, 5)
+ local model = nn.LinearWeightNorm(5, 20)
+
+ -- check gradient
+ local err = nn.Jacobian.testJacobianParameters(model, input, model.bias, model.gradBias)
+ mytester:assert(err < precision, 'bias')
+ err = nn.Jacobian.testJacobianParameters(model, input, model.g, model.gradG)
+ mytester:assert(err < precision, 'g')
+ err = nn.Jacobian.testJacobianParameters(model, input, model.v, model.gradV)
+ mytester:assert(err < precision, 'v')
+
+ -- check conversion functions
+ local linear = nn.Linear(5,20)
+ local wnFromLin = nn.LinearWeightNorm.fromLinear(linear)
+ local linFromWn = wnFromLin:toLinear()
+
+ local linOut = linear:forward(input)
+ local wnOut = wnFromLin:forward(input)
+ local linFromWnOut = linFromWn:forward(input)
+
+ mytester:assertTensorEq(linOut, wnOut, precision, "outputs are not equivalent")
+ mytester:assertTensorEq(wnOut, linFromWnOut, precision, "outputs are not equivalent")
+
+ -- check conversion with nobias
+ linear = nn.Linear(5,20,false)
+ wnFromLin = nn.LinearWeightNorm.fromLinear(linear)
+ linFromWn = wnFromLin:toLinear()
+
+ linOut = linear:forward(input)
+ wnOut = wnFromLin:forward(input)
+ linFromWnOut = linFromWn:forward(input)
+
+ mytester:assertTensorEq(linear.weight, wnFromLin.weight, precision, "weights are not equivalent")
+ mytester:assert(not wnFromLin.bias)
+ mytester:assert(not linear.bias)
+ mytester:assertTensorEq(linOut, wnOut, precision, "outputs are not equivalent")
+ mytester:assertTensorEq(wnOut, linFromWnOut, precision, "outputs are not equivalent")
+
+ -- check gradient with nobias
+ model = wnFromLin
+
+ err = nn.Jacobian.testJacobianParameters(model, input, model.g, model.gradG)
+ mytester:assert(err < precision, 'g')
+ err = nn.Jacobian.testJacobianParameters(model, input, model.v, model.gradV)
+ mytester:assert(err < precision, 'v')
+end
+
function nntest.CAdd()
local function testBackwardPass(module, input, params, dparams)
local err = jac.testJacobian(module,input)
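To exercise the new test in isolation, one common invocation is via nn's built-in test harness (adjust to your local runner if it differs):

```lua
require 'nn'
nn.test{'LinearWeightNorm'}   -- runs only the nntest.LinearWeightNorm case
```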