diff options
author | Richard Assar <richard.assar@gmail.com> | 2017-03-07 02:58:09 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-11 21:50:25 +0300 |
commit | 65f42723f51d13683a431329452e5f385988e984 (patch) | |
tree | 44dcda0026791ffdb085d6602a5a0dd91d5020b1 | |
parent | 5f9f4f87c896cd1a8bf0999c49c892e368368ad7 (diff) |
Adding LinearWeightNorm layer, test and updated documentation
-rw-r--r-- | Linear.lua | 6 | ||||
-rwxr-xr-x | LinearWeightNorm.lua | 168 | ||||
-rwxr-xr-x[-rw-r--r--] | doc/containers.md | 7 | ||||
-rwxr-xr-x[-rw-r--r--] | doc/simple.md | 14 | ||||
-rwxr-xr-x[-rw-r--r--] | init.lua | 1 | ||||
-rwxr-xr-x | test.lua | 48 |
6 files changed, 239 insertions, 5 deletions
@@ -42,7 +42,7 @@ function Linear:reset(stdv)
    return self
 end
 
-local function updateAddBuffer(self, input)
+function Linear:updateAddBuffer(input)
    local nframe = input:size(1)
    self.addBuffer = self.addBuffer or input.new()
    if self.addBuffer:nElement() ~= nframe then
@@ -62,7 +62,7 @@ function Linear:updateOutput(input)
       if self.output:nElement() ~= nElement then
          self.output:zero()
       end
-      updateAddBuffer(self, input)
+      self:updateAddBuffer(input)
       self.output:addmm(0, self.output, 1, input, self.weight:t())
       if self.bias then self.output:addr(1, self.addBuffer, self.bias) end
    else
@@ -99,7 +99,7 @@ function Linear:accGradParameters(input, gradOutput, scale)
       self.gradWeight:addmm(scale, gradOutput:t(), input)
       if self.bias then
          -- update the size of addBuffer if the input is not the same size as the one we had in last updateGradInput
-         updateAddBuffer(self, input)
+         self:updateAddBuffer(input)
          self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
       end
    end
diff --git a/LinearWeightNorm.lua b/LinearWeightNorm.lua
new file mode 100755
index 0000000..a712f55
--- /dev/null
+++ b/LinearWeightNorm.lua
@@ -0,0 +1,211 @@
+-- nn.LinearWeightNorm: a linear layer whose weight is reparametrized as
+-- w = g * v / ||v|| (one norm per output row of v). The trainable parameters are
+-- v, g and the optional bias; self.weight only caches w for the code inherited
+-- from nn.Linear.
+local LinearWeightNorm, parent = torch.class('nn.LinearWeightNorm', 'nn.Linear')
+
+-- inputSize, outputSize: dimensions of the layer.
+-- bias: pass false to build the layer without a bias term (default true).
+-- eps: lower bound applied to the row norms of v (default 1e-16).
+function LinearWeightNorm:__init(inputSize, outputSize, bias, eps)
+    nn.Module.__init(self) -- Skip nn.Linear constructor
+
+    if bias == nil then bias = true end
+
+    self.eps = eps or 1e-16
+
+    self.outputSize = outputSize
+    self.inputSize = inputSize
+
+    self.v = torch.Tensor(outputSize, inputSize)
+    self.gradV = torch.Tensor(outputSize, inputSize)
+
+    self.weight = torch.Tensor(outputSize, inputSize)
+
+    self.g = torch.Tensor(outputSize,1)
+    self.gradG = torch.Tensor(outputSize,1)
+
+    self.norm = torch.Tensor(outputSize,1)
+    self.scale = torch.Tensor(outputSize,1)
+
+    if bias then
+        self.bias = torch.Tensor(outputSize)
+        self.gradBias = torch.Tensor(outputSize)
+    end
+
+    self:reset()
+end
+
+-- Refresh the cached weight matrix before leaving training mode, because
+-- updateOutput reuses the cached matrix while self.train is false.
+function LinearWeightNorm:evaluate()
+    if self.train ~= false then
+        self:updateWeightMatrix()
+    end
+
+    parent.evaluate(self)
+end
+
+-- Initialise g and v from a weight matrix (defaults to self.weight):
+-- g takes the row norms (clamped to at least eps) and v a copy of the weight.
+function LinearWeightNorm:initFromWeight(weight)
+    weight = weight or self.weight
+
+    self.g:norm(weight,2,2):clamp(self.eps,math.huge)
+    self.v:copy(weight)
+
+    return self
+end
+
+-- Build a LinearWeightNorm computing the same function as the given nn.Linear.
+function LinearWeightNorm.fromLinear(linear)
+    local module = nn.LinearWeightNorm(linear.weight:size(2), linear.weight:size(1), torch.isTensor(linear.bias))
+    module.weight:copy(linear.weight)
+    module:initFromWeight()
+
+    if linear.bias then
+        module.bias:copy(linear.bias)
+    end
+
+    return module
+end
+
+-- Build a plain nn.Linear holding the current effective weight and bias.
+function LinearWeightNorm:toLinear()
+    self:updateWeightMatrix()
+
+    local module = nn.Linear(self.inputSize, self.outputSize, torch.isTensor(self.bias))
+
+    module.weight:copy(self.weight)
+    if self.bias then
+        module.bias:copy(self.bias)
+    end
+
+    return module
+end
+
+-- The trainable parameters are v, g and bias; the cached weight is derived from them.
+function LinearWeightNorm:parameters()
+    if self.bias then
+        return {self.v, self.g, self.bias}, {self.gradV, self.gradG, self.gradBias}
+    else
+        return {self.v, self.g}, {self.gradV, self.gradG}
+    end
+end
+
+-- Re-initialise the parameters uniformly in [-stdv, stdv].
+-- Returns self, like nn.Linear:reset, so that calls can be chained.
+function LinearWeightNorm:reset(stdv)
+    if stdv then
+        stdv = stdv * math.sqrt(3)
+    else
+        stdv = 1 / math.sqrt(self.inputSize)
+    end
+
+    self.weight:uniform(-stdv,stdv)
+    self:initFromWeight()
+
+    if self.bias then
+        self.bias:uniform(-stdv,stdv)
+    end
+
+    return self
+end
+
+-- Recompute norm = ||v|| (row-wise, clamped to eps), scale = g / norm and
+-- weight = v * scale. The buffers are resized first if clearState emptied them.
+function LinearWeightNorm:updateWeightMatrix()
+    if self.norm:dim() == 0 then self.norm:resizeAs(self.g) end
+    if self.scale:dim() == 0 then self.scale:resizeAs(self.g) end
+    if self.weight:dim() == 0 then self.weight:resizeAs(self.v) end
+
+    self.norm:norm(self.v,2,2):clamp(self.eps,math.huge)
+    self.scale:cdiv(self.g,self.norm)
+    self.weight:cmul(self.v,self.scale:expandAs(self.v))
+end
+
+function LinearWeightNorm:updateOutput(input)
+    -- In evaluation mode the cached weight is reused, unless clearState emptied it.
+    if self.train ~= false or self.weight:dim() == 0 then
+        self:updateWeightMatrix()
+    end
+
+    return parent.updateOutput(self, input)
+end
+
+-- Accumulate the gradients w.r.t. v, g and bias, scaled by `scale` (default 1).
+-- Relies on self.norm and self.scale as left by the last updateWeightMatrix call.
+function LinearWeightNorm:accGradParameters(input, gradOutput, scale)
+    scale = scale or 1
+
+    -- The gradient w.r.t. the effective weight is built in scratch buffers, so that
+    -- gradV and gradG are only ever added to (repeated calls accumulate) and the
+    -- cached self.weight is not overwritten.
+    self._gradW = self._gradW or self.v.new()
+    self._gradWBuffer = self._gradWBuffer or self.v.new()
+    self._gradGBuffer = self._gradGBuffer or self.g.new()
+
+    local gradW = self._gradW:resizeAs(self.v):zero()
+    local buffer = self._gradWBuffer
+    local gradG = self._gradGBuffer
+
+    if input:dim() == 1 then
+        gradW:addr(scale, gradOutput, input)
+        if self.bias then self.gradBias:add(scale, gradOutput) end
+    elseif input:dim() == 2 then
+        gradW:addmm(scale, gradOutput:t(), input)
+        if self.bias then
+            -- update the size of addBuffer if the input is not the same size as the one we had in last updateGradInput
+            self:updateAddBuffer(input)
+            self.gradBias:addmv(scale, gradOutput:t(), self.addBuffer)
+        end
+    end
+
+    local scaleExp = self.scale:expandAs(self.v) -- g / ||v||
+    local normExp = self.norm:expandAs(self.v) -- ||v||
+
+    -- dL/dg = sum_j (dL/dw * v) / ||v||
+    buffer:cmul(gradW, self.v):cdiv(normExp)
+    gradG:sum(buffer, 2)
+
+    -- dL/dv = (g / ||v||) * dL/dw - (g / ||v||) * (v / ||v||) * dL/dg
+    gradW:cmul(scaleExp)
+    buffer:cmul(self.v, scaleExp):cdiv(normExp)
+    buffer:cmul(gradG:expandAs(buffer))
+    gradW:add(-1, buffer)
+
+    -- v and g are read above and only written here, which keeps the update correct
+    -- when defaultAccUpdateGradParameters aliases gradV/gradG to v/g.
+    self.gradV:add(gradW)
+    self.gradG:add(gradG)
+end
+
+-- In-place parameter update: the gradient buffers are temporarily aliased to the
+-- parameters so that accGradParameters adds -lr * gradient directly into them.
+function LinearWeightNorm:defaultAccUpdateGradParameters(input, gradOutput, lr)
+    local gradV = self.gradV
+    local gradG = self.gradG
+    local gradBias = self.gradBias
+
+    self.gradV = self.v
+    self.gradG = self.g
+    self.gradBias = self.bias
+
+    self:accGradParameters(input, gradOutput, -lr)
+
+    self.gradV = gradV
+    self.gradG = gradG
+    self.gradBias = gradBias
+end
+
+-- Release the derived buffers; they are rebuilt on the next forward/backward pass.
+function LinearWeightNorm:clearState()
+    nn.utils.clear(self, 'weight', 'norm', 'scale', '_gradW', '_gradWBuffer', '_gradGBuffer')
+    return parent.clearState(self)
+end
+
+function LinearWeightNorm:__tostring__()
+    return torch.type(self) ..
+        string.format('(%d -> %d)', self.inputSize, self.outputSize) ..
+        (self.bias == nil and ' without bias' or '')
+end
\ No newline at end of file diff --git a/doc/containers.md b/doc/containers.md index 5e404e3..cecf782 100644..100755 --- a/doc/containers.md +++ b/doc/containers.md @@ -346,9 +346,12 @@ mlp = nn.Bottle(nn.Linear(10, 2)) module = nn.WeightNorm(module) ``` -WeightNorm implements the reparametrization presented in [Weight Normalization](https://arxiv.org/pdf/1602.07868v3.pdf), which decouples the length of neural network weight vectors from their direction. The weight vectors `w` is determined instead by parameters `g` and `v` such that `w = g * v / ||v||`, where `||v||` is the euclidean norm of vector v. This container can wrap nn layers with weights. +WeightNorm implements the reparametrization presented in [Weight Normalization](https://arxiv.org/pdf/1602.07868v3.pdf), which decouples the length of neural network weight vectors from their direction. The weight vector `w` is determined instead by parameters `g` and `v` such that `w = g * v / ||v||`, where `||v||` is the euclidean norm of vector `v`. This container can wrap nn layers with weights. + +It accepts a parameter ``outputDim`` that represents the output dimension of the module weight it wraps, which defaults to 1. If the `outputDim` is not 1, the container will transpose the weight appropriately. If the module weight is not 2D, e.g. in the case of convolutional layers, the container will view the weight into an appropriate 2D shape based on the `outputDim` specified by the user. + +An optimised version of `nn.WeightNorm(nn.Linear(inputDimension, outputDimension))` is available as `nn.LinearWeightNorm(inputDimension, outputDimension, [bias = true])`. This layer occupies less memory and is faster because it uses fewer tensor copy operations. In evaluation mode it also reuses the cached weight matrix instead of recomputing it on every forward pass. -It accepts a parameter ``outputDim`` that represents the output dimension of the module weight it wraps, which defaults to 1.
If the outputDim is not 1, the container will transpose the weight appropriately. If the module weight is not 2D, the container will view the weight into an appropriate 2D shape based on the outputDim specified by the user. <a name='nn.DontCast'></a> ## DontCast ## diff --git a/doc/simple.md b/doc/simple.md index 9306edf..2be635c 100644..100755 --- a/doc/simple.md +++ b/doc/simple.md @@ -4,6 +4,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi * Parameterized Modules : * [Linear](#nn.Linear) : a linear transformation ; + * [LinearWeightNorm](#nn.LinearWeightNorm) : a weight normalized linear transformation ; * [SparseLinear](#nn.SparseLinear) : a linear transformation with sparse inputs ; * [IndexLinear](#nn.IndexLinear) : an alternative linear transformation with for sparse inputs and max normalization ; * [Bilinear](#nn.Bilinear) : a bilinear transformation with sparse inputs ; @@ -100,6 +101,19 @@ x = torch.Tensor(10) -- 10 inputs y = module:forward(x) ``` +<a name="nn.LinearWeightNorm"></a> +## LinearWeightNorm ## + +```lua +module = nn.LinearWeightNorm(inputDimension, outputDimension, [bias = true]) +``` + +LinearWeightNorm implements the reparametrization presented in [Weight Normalization](https://arxiv.org/pdf/1602.07868v3.pdf), which decouples the length of neural network weight vectors from their direction. The weight vector `w` is determined instead by parameters `g` and `v` such that `w = g * v / ||v||`, where `||v||` is the euclidean norm of vector `v`. In all other respects this layer behaves like `nn.Linear`. + +To convert between `nn.Linear` and `nn.LinearWeightNorm` you can use the `nn.LinearWeightNorm.fromLinear(linearModule)` and `weightNormModule:toLinear()` functions. + +Other layer types can make use of weight normalization through the [nn.WeightNorm](https://github.com/torch/nn/blob/master/doc/containers.md#nn.WeightNorm) container. 
+
 <a name="nn.SparseLinear"></a>
 ## SparseLinear ##
@@ -24,6 +24,7 @@ require('nn.NaN')
 require('nn.Profile')
 
 require('nn.Linear')
+require('nn.LinearWeightNorm')
 require('nn.Bilinear')
 require('nn.PartialLinear')
 require('nn.SparseLinear')
@@ -177,6 +177,53 @@ function nntest.WeightNorm()
    mytester:assertTensorEq(out, outr)
 end
 
+function nntest.LinearWeightNorm()
+   local input = torch.rand(10, 5)
+
+   -- Jacobian check of one parameter tensor against its gradient buffer
+   local function checkGradient(module, name, param, gradParam)
+      local err = nn.Jacobian.testJacobianParameters(module, input, param, gradParam)
+      mytester:assert(err < precision, name)
+   end
+
+   -- Round trip nn.Linear -> nn.LinearWeightNorm -> nn.Linear, forwarding `input` through all three
+   local function convert(linear)
+      local wnFromLin = nn.LinearWeightNorm.fromLinear(linear)
+      local linFromWn = wnFromLin:toLinear()
+      local linOut = linear:forward(input)
+      local wnOut = wnFromLin:forward(input)
+      local linFromWnOut = linFromWn:forward(input)
+      return wnFromLin, linOut, wnOut, linFromWnOut
+   end
+
+   -- check gradient
+   local model = nn.LinearWeightNorm(5, 20)
+   checkGradient(model, 'bias', model.bias, model.gradBias)
+   checkGradient(model, 'g', model.g, model.gradG)
+   checkGradient(model, 'v', model.v, model.gradV)
+
+   -- check conversion functions
+   local linear = nn.Linear(5,20)
+   local wnFromLin, linOut, wnOut, linFromWnOut = convert(linear)
+
+   mytester:assertTensorEq(linOut, wnOut, precision, "outputs are not equivalent")
+   mytester:assertTensorEq(wnOut, linFromWnOut, precision, "outputs are not equivalent")
+
+   -- check conversion with nobias
+   linear = nn.Linear(5,20,false)
+   wnFromLin, linOut, wnOut, linFromWnOut = convert(linear)
+
+   mytester:assertTensorEq(linear.weight, wnFromLin.weight, precision, "weights are not equivalent")
+   mytester:assert(not wnFromLin.bias)
+   mytester:assert(not linear.bias)
+   mytester:assertTensorEq(linOut, wnOut, precision, "outputs are not equivalent")
+   mytester:assertTensorEq(wnOut, linFromWnOut, precision, "outputs are not equivalent")
+
+   -- check gradient with nobias
+   checkGradient(wnFromLin, 'g', wnFromLin.g, wnFromLin.gradG)
+   checkGradient(wnFromLin, 'v', wnFromLin.v, wnFromLin.gradV)
+end
+
 function nntest.CAdd()
    local function testBackwardPass(module, input, params, dparams)
       local err = jac.testJacobian(module,input) |